diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 112596057dd9..3d5fd84ad949 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -40,7 +40,7 @@ jobs: framework: pytorch_examples runner: docker-cpu image: diffusers/diffusers-pytorch-cpu - report: torch_cpu + report: torch_example_cpu name: ${{ matrix.config.name }} diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index bf830959cf01..525df28cbaa8 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -38,7 +38,7 @@ jobs: framework: pytorch_examples runner: docker-cpu image: diffusers/diffusers-pytorch-cpu - report: torch_cpu + report: torch_example_cpu name: ${{ matrix.config.name }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e0e873892ca2..5ce48793e9c2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -394,8 +394,15 @@ passes. You should run the tests impacted by your changes like this: ```bash $ pytest tests/.py ``` + +Before you run the tests, please make sure you install the dependencies required for testing. You can do so +with this command: -You can also run the full suite with the following command, but it takes + ```bash + $ pip install -e ".[test]" + ``` + +You can run the full test suite with the following command, but it takes a beefy machine to produce a result in a decent amount of time now that Diffusers has grown a lot. Here is the command for it: @@ -439,7 +446,7 @@ Push the changes to your account using: $ git push -u origin a-descriptive-name-for-my-changes ``` -6. Once you are satisfied (**and the checklist below is happy too**), go to the +6. Once you are satisfied, go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3ed5ad159982..d74bd3785343 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -4,7 +4,7 @@ - local: quicktour title: Quicktour - local: stable_diffusion - title: Stable Diffusion + title: Effective and efficient diffusion - local: installation title: Installation title: Get started @@ -33,15 +33,15 @@ - local: using-diffusers/pipeline_overview title: Overview - local: using-diffusers/unconditional_image_generation - title: Unconditional Image Generation + title: Unconditional image generation - local: using-diffusers/conditional_image_generation - title: Text-to-Image Generation + title: Text-to-image generation - local: using-diffusers/img2img - title: Text-Guided Image-to-Image + title: Text-guided image-to-image - local: using-diffusers/inpaint - title: Text-Guided Image-Inpainting + title: Text-guided image-inpainting - local: using-diffusers/depth2img - title: Text-Guided Depth-to-Image + title: Text-guided depth-to-image - local: using-diffusers/reusing_seeds title: Improve image quality with deterministic generation - local: using-diffusers/reproducibility @@ -52,6 +52,8 @@ title: How to contribute a Pipeline - local: using-diffusers/using_safetensors title: Using safetensors + - local: using-diffusers/stable_diffusion_jax_how_to + title: Stable Diffusion in JAX/Flax - local: using-diffusers/weighted_prompts title: Weighting Prompts title: Pipelines for Inference @@ -95,6 +97,8 @@ title: ONNX - local: optimization/open_vino title: OpenVINO + - local: optimization/coreml + title: Core ML - local: optimization/mps title: MPS - local: optimization/habana @@ -134,6 +138,8 @@ title: AltDiffusion - local: api/pipelines/audio_diffusion title: Audio Diffusion + - local: api/pipelines/audioldm + title: AudioLDM - local: api/pipelines/cycle_diffusion title: Cycle Diffusion - local: api/pipelines/dance_diffusion @@ -158,6 +164,8 @@ title: Score SDE VE - local: api/pipelines/semantic_stable_diffusion title: Semantic Guidance + - local: api/pipelines/spectrogram_diffusion + title: "Spectrogram Diffusion" - sections: - local: api/pipelines/stable_diffusion/overview title: Overview @@ -187,6 +195,8 @@ title: MultiDiffusion Panorama - local: api/pipelines/stable_diffusion/controlnet title: Text-to-Image Generation with ControlNet Conditioning + - local: api/pipelines/stable_diffusion/model_editing + title: Text-to-Image Model Editing title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 @@ -196,6 +206,8 @@ title: Stochastic Karras VE - local: api/pipelines/text_to_video title: Text-to-Video + - local: api/pipelines/text_to_video_zero + title: Text-to-Video Zero - local: api/pipelines/unclip title: UnCLIP - local: api/pipelines/latent_diffusion_uncond diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx index 1d55bd03c064..8cbf21b8e0cf 100644 --- a/docs/source/en/api/loaders.mdx +++ b/docs/source/en/api/loaders.mdx @@ -28,3 +28,11 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g ### UNet2DConditionLoadersMixin [[autodoc]] loaders.UNet2DConditionLoadersMixin + +### TextualInversionLoaderMixin + +[[autodoc]] loaders.TextualInversionLoaderMixin + +### LoraLoaderMixin + +[[autodoc]] loaders.LoraLoaderMixin diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index 572f8873ba12..2361fd4f6597 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## FlaxAutoencoderKL [[autodoc]] FlaxAutoencoderKL + +## FlaxControlNetOutput +[[autodoc]] models.controlnet_flax.FlaxControlNetOutput + +## FlaxControlNetModel +[[autodoc]] FlaxControlNetModel diff --git a/docs/source/en/api/pipelines/alt_diffusion.mdx b/docs/source/en/api/pipelines/alt_diffusion.mdx index cb86208ddbe1..8463fd51ddbb 100644 --- a/docs/source/en/api/pipelines/alt_diffusion.mdx +++ b/docs/source/en/api/pipelines/alt_diffusion.mdx @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # AltDiffusion -AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu +AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu. The abstract of the paper is the following: @@ -28,11 +28,11 @@ The abstract of the paper is the following: ## Tips -- AltDiffusion is conceptually exaclty the same as [Stable Diffusion](./api/pipelines/stable_diffusion/overview). +- AltDiffusion is conceptually exactly the same as [Stable Diffusion](./stable_diffusion/overview). - *Run AltDiffusion* -AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion-m9"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img). +AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion-m9"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](../../using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](../../using-diffusers/img2img). - *How to load and use different schedulers.* diff --git a/docs/source/en/api/pipelines/audioldm.mdx b/docs/source/en/api/pipelines/audioldm.mdx new file mode 100644 index 000000000000..f3987d2263ac --- /dev/null +++ b/docs/source/en/api/pipelines/audioldm.mdx @@ -0,0 +1,82 @@ + + +# AudioLDM + +## Overview + +AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://arxiv.org/abs/2301.12503) by Haohe Liu et al. + +Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM +is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap) +latents. AudioLDM takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional +sound effects, human speech and music. + +This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found [here](https://github.com/haoheliu/AudioLDM). + +## Text-to-Audio + +The [`AudioLDMPipeline`] can be used to load pre-trained weights from [cvssp/audioldm](https://huggingface.co/cvssp/audioldm) and generate text-conditional audio outputs: + +```python +from diffusers import AudioLDMPipeline +import torch +import scipy + +repo_id = "cvssp/audioldm" +pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" +audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + +# save the audio sample as a .wav file +scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) +``` + +### Tips + +Prompts: +* Descriptive prompt inputs work best: you can use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g., "water stream in a forest" instead of "stream"). +* It's best to use general terms like 'cat' or 'dog' instead of specific names or abstract objects that the model may not be familiar with. + +Inference: +* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument: higher steps give higher quality audio at the expense of slower inference. +* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument. + +### How to load and use different schedulers + +The AudioLDM pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers +that can be used with the AudioLDM pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], +[`EulerAncestralDiscreteScheduler`] etc. We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest +scheduler there is. + +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] +method, or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the +[`DPMSolverMultistepScheduler`], you can do the following: + +```python +>>> from diffusers import AudioLDMPipeline, DPMSolverMultistepScheduler +>>> import torch + +>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16) +>>> pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained("cvssp/audioldm", subfolder="scheduler") +>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", scheduler=dpm_scheduler, torch_dtype=torch.float16) +``` + +## AudioLDMPipeline +[[autodoc]] AudioLDMPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 3bf29888ae54..3c5331955513 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -19,9 +19,9 @@ components - all of which are needed to have a functioning end-to-end diffusion As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models: - [Autoencoder](./api/models#vae) - [Conditional Unet](./api/models#UNet2DConditionModel) -- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPTextModel) +- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.27.1/en/model_doc/clip#transformers.CLIPTextModel) - a scheduler component, [scheduler](./api/scheduler#pndm), -- a [CLIPFeatureExtractor](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPFeatureExtractor), +- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/v4.27.1/en/model_doc/clip#transformers.CLIPImageProcessor), - as well as a [safety checker](./stable_diffusion#safety_checker). All of these components are necessary to run stable diffusion in inference even though they were trained or created independently from each other. @@ -83,6 +83,7 @@ available a colab notebook to directly try them out. | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | | [vq_diffusion](./vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | +| [text_to_video_zero](./text_to_video_zero) | [Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) | Text-to-Video Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. @@ -108,7 +109,7 @@ from the local path. each pipeline, one should look directly into the respective pipeline. **Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should -not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community) +not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). ## Contribution @@ -173,7 +174,7 @@ You can also run this example on colab [![Open In Colab](https://colab.research. ### Tweak prompts reusing seeds and latents -You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). +You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ### In-painting using Stable Diffusion diff --git a/docs/source/en/api/pipelines/paint_by_example.mdx b/docs/source/en/api/pipelines/paint_by_example.mdx index 04390a14b758..5abb3406db44 100644 --- a/docs/source/en/api/pipelines/paint_by_example.mdx +++ b/docs/source/en/api/pipelines/paint_by_example.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. ## Overview -[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen +[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen. The abstract of the paper is the following: diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx index f1b2cc3892dd..b4562cf0c389 100644 --- a/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx +++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.mdx @@ -24,11 +24,11 @@ The abstract of the paper is the following: | Pipeline | Tasks | Colab | Demo |---|---|:---:|:---:| -| [pipeline_semantic_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) | [Coming Soon](https://huggingface.co/AIML-TUDA) +| [pipeline_semantic_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) | [Coming Soon](https://huggingface.co/AIML-TUDA) ## Tips -- The Semantic Guidance pipeline can be used with any [Stable Diffusion](./api/pipelines/stable_diffusion/text2img) checkpoint. +- The Semantic Guidance pipeline can be used with any [Stable Diffusion](./stable_diffusion/text2img) checkpoint. ### Run Semantic Guidance @@ -67,7 +67,7 @@ out = pipe( ) ``` -For more examples check the colab notebook. +For more examples check the Colab notebook. ## StableDiffusionSafePipelineOutput [[autodoc]] pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/spectrogram_diffusion.mdx b/docs/source/en/api/pipelines/spectrogram_diffusion.mdx new file mode 100644 index 000000000000..728c6b3aa2f2 --- /dev/null +++ b/docs/source/en/api/pipelines/spectrogram_diffusion.mdx @@ -0,0 +1,54 @@ + + +# Multi-instrument Music Synthesis with Spectrogram Diffusion + +## Overview + +[Spectrogram Diffusion](https://arxiv.org/abs/2206.05408) by Curtis Hawthorne, Ian Simon, Adam Roberts, Neil Zeghidour, Josh Gardner, Ethan Manilow, and Jesse Engel. + +An ideal music synthesizer should be both interactive and expressive, generating high-fidelity audio in realtime for arbitrary combinations of instruments and notes. Recent neural synthesizers have exhibited a tradeoff between domain-specific models that offer detailed control of only specific instruments, or raw waveform models that can train on any music but with minimal control and slow generation. In this work, we focus on a middle ground of neural synthesizers that can generate audio from MIDI sequences with arbitrary combinations of instruments in realtime. This enables training on a wide range of transcription datasets with a single model, which in turn offers note-level control of composition and instrumentation across a wide range of instruments. We use a simple two-stage process: MIDI to spectrograms with an encoder-decoder Transformer, then spectrograms to audio with a generative adversarial network (GAN) spectrogram inverter. We compare training the decoder as an autoregressive model and as a Denoising Diffusion Probabilistic Model (DDPM) and find that the DDPM approach is superior both qualitatively and as measured by audio reconstruction and FrΓ©chet distance metrics. Given the interactivity and generality of this approach, we find this to be a promising first step towards interactive and expressive neural synthesis for arbitrary combinations of instruments and notes. + +The original codebase of this implementation can be found at [magenta/music-spectrogram-diffusion](https://github.com/magenta/music-spectrogram-diffusion). + +## Model + +![img](https://storage.googleapis.com/music-synthesis-with-spectrogram-diffusion/architecture.png) + +As depicted above the model takes as input a MIDI file and tokenizes it into a sequence of 5 second intervals. Each tokenized interval then together with positional encodings is passed through the Note Encoder and its representation is concatenated with the previous window's generated spectrogram representation obtained via the Context Encoder. For the initial 5 second window this is set to zero. The resulting context is then used as conditioning to sample the denoised Spectrogram from the MIDI window and we concatenate this spectrogram to the final output as well as use it for the context of the next MIDI window. The process repeats till we have gone over all the MIDI inputs. Finally a MelGAN decoder converts the potentially long spectrogram to audio which is the final result of this pipeline. + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_spectrogram_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py) | *Unconditional Audio Generation* | - | + + +## Example usage + +```python +from diffusers import SpectrogramDiffusionPipeline, MidiProcessor + +pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") +pipe = pipe.to("cuda") +processor = MidiProcessor() + +# Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid +output = pipe(processor("beethoven_hammerklavier_2.mid")) + +audio = output.audios[0] +``` + +## SpectrogramDiffusionPipeline +[[autodoc]] SpectrogramDiffusionPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx index aafbf5b05d79..5a4cfa41ca43 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -131,7 +131,7 @@ This should take only around 3-4 seconds on GPU (depending on hardware). The out ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_disco_dancing.png) -**Note**: To see how to run all other ControlNet checkpoints, please have a look at [ControlNet with Stable Diffusion 1.5](#controlnet-with-stable-diffusion-1.5) +**Note**: To see how to run all other ControlNet checkpoints, please have a look at [ControlNet with Stable Diffusion 1.5](#controlnet-with-stable-diffusion-1.5). @@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h - disable_vae_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + +## FlaxStableDiffusionControlNetPipeline +[[autodoc]] FlaxStableDiffusionControlNetPipeline + - all + - __call__ + diff --git a/docs/source/en/api/pipelines/stable_diffusion/image_variation.mdx b/docs/source/en/api/pipelines/stable_diffusion/image_variation.mdx index 939732f4c274..8ca69ff69aec 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/image_variation.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/image_variation.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. ## StableDiffusionImageVariationPipeline -[`StableDiffusionImageVariationPipeline`] lets you generate variations from an input image using Stable Diffusion. It uses a fine-tuned version of Stable Diffusion model, trained by [Justin Pinkney](https://www.justinpinkney.com/) (@Buntworthy) at [Lambda](https://lambdalabs.com/) +[`StableDiffusionImageVariationPipeline`] lets you generate variations from an input image using Stable Diffusion. It uses a fine-tuned version of Stable Diffusion model, trained by [Justin Pinkney](https://www.justinpinkney.com/) (@Buntworthy) at [Lambda](https://lambdalabs.com/). The original codebase can be found here: [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) @@ -28,4 +28,4 @@ Available Checkpoints are: - enable_attention_slicing - disable_attention_slicing - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention \ No newline at end of file + - disable_xformers_memory_efficient_attention diff --git a/docs/source/en/api/pipelines/stable_diffusion/model_editing.mdx b/docs/source/en/api/pipelines/stable_diffusion/model_editing.mdx new file mode 100644 index 000000000000..7aae35ba2a91 --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/model_editing.mdx @@ -0,0 +1,61 @@ + + +# Editing Implicit Assumptions in Text-to-Image Diffusion Models + +## Overview + +[Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://arxiv.org/abs/2303.08084) by Hadas Orgad, Bahjat Kawar, and Yonatan Belinkov. + +The abstract of the paper is the following: + +*Text-to-image diffusion models often make implicit assumptions about the world when generating images. While some assumptions are useful (e.g., the sky is blue), they can also be outdated, incorrect, or reflective of social biases present in the training data. Thus, there is a need to control these assumptions without requiring explicit user input or costly re-training. In this work, we aim to edit a given implicit assumption in a pre-trained diffusion model. Our Text-to-Image Model Editing method, TIME for short, receives a pair of inputs: a "source" under-specified prompt for which the model makes an implicit assumption (e.g., "a pack of roses"), and a "destination" prompt that describes the same setting, but with a specified desired attribute (e.g., "a pack of blue roses"). TIME then updates the model's cross-attention layers, as these layers assign visual meaning to textual tokens. We edit the projection matrices in these layers such that the source prompt is projected close to the destination prompt. Our method is highly efficient, as it modifies a mere 2.2% of the model's parameters in under one second. To evaluate model editing approaches, we introduce TIMED (TIME Dataset), containing 147 source and destination prompt pairs from various domains. Our experiments (using Stable Diffusion) show that TIME is successful in model editing, generalizes well for related prompts unseen during editing, and imposes minimal effect on unrelated generations.* + +Resources: + +* [Project Page](https://time-diffusion.github.io/). +* [Paper](https://arxiv.org/abs/2303.08084). +* [Original Code](https://github.com/bahjat-kawar/time-diffusion). +* [Demo](https://huggingface.co/spaces/bahjat-kawar/time-diffusion). + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [StableDiffusionModelEditingPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py) | *Text-to-Image Model Editing* | [πŸ€— Space](https://huggingface.co/spaces/bahjat-kawar/time-diffusion)) | + +This pipeline enables editing the diffusion model weights, such that its assumptions on a given concept are changed. The resulting change is expected to take effect in all prompt generations pertaining to the edited concept. + +## Usage example + +```python +import torch +from diffusers import StableDiffusionModelEditingPipeline + +model_ckpt = "CompVis/stable-diffusion-v1-4" +pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + +pipe = pipe.to("cuda") + +source_prompt = "A pack of roses" +destination_prompt = "A pack of blue roses" +pipe.edit_model(source_prompt, destination_prompt) + +prompt = "A field of roses" +image = pipe(prompt).images[0] +image.save("field_of_roses.png") +``` + +## StableDiffusionModelEditingPipeline +[[autodoc]] StableDiffusionModelEditingPipeline + - __call__ + - all diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 160fa0d2ebce..70731fd294b9 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -35,6 +35,7 @@ For more details about how Stable Diffusion works and how it differs from the ba | [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** – *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix) | [StableDiffusionAttendAndExcitePipeline](./attend_and_excite) | **Experimental** – *Text-to-Image Generation * | | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite) | [StableDiffusionPix2PixZeroPipeline](./pix2pix_zero) | **Experimental** – *Text-Based Image Editing * | | [Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) +| [StableDiffusionModelEditingPipeline](./model_editing) | **Experimental** – *Text-to-Image Model Editing * | | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://arxiv.org/abs/2303.08084) diff --git a/docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx b/docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx index b34c1f51cf66..133f2b775d71 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx @@ -14,25 +14,26 @@ specific language governing permissions and limitations under the License. ## Overview -[Self-Attention Guidance](https://arxiv.org/abs/2210.00939) by Susung Hong et al. +[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) by Susung Hong et al. The abstract of the paper is the following: -*Denoising diffusion models (DDMs) have been drawing much attention for their appreciable sample quality and diversity. Despite their remarkable performance, DDMs remain black boxes on which further study is necessary to take a profound step. Motivated by this, we delve into the design of conventional U-shaped diffusion models. More specifically, we investigate the self-attention modules within these models through carefully designed experiments and explore their characteristics. In addition, inspired by the studies that substantiate the effectiveness of the guidance schemes, we present plug-and-play diffusion guidance, namely Self-Attention Guidance (SAG), that can drastically boost the performance of existing diffusion models. Our method, SAG, extracts the intermediate attention map from a diffusion model at every iteration and selects tokens above a certain attention score for masking and blurring to obtain a partially blurred input. Subsequently, we measure the dissimilarity between the predicted noises obtained from feeding the blurred and original input to the diffusion model and leverage it as guidance. With this guidance, we observe apparent improvements in a wide range of diffusion models, e.g., ADM, IDDPM, and Stable Diffusion, and show that the results further improve by combining our method with the conventional guidance scheme. We provide extensive ablation studies to verify our choices.* +*Denoising diffusion models (DDMs) have attracted attention for their exceptional generation quality and diversity. This success is largely attributed to the use of class- or text-conditional diffusion guidance methods, such as classifier and classifier-free guidance. In this paper, we present a more comprehensive perspective that goes beyond the traditional guidance methods. From this generalized perspective, we introduce novel condition- and training-free strategies to enhance the quality of generated images. As a simple solution, blur guidance improves the suitability of intermediate samples for their fine-scale information and structures, enabling diffusion models to generate higher quality samples with a moderate guidance scale. Improving upon this, Self-Attention Guidance (SAG) uses the intermediate self-attention maps of diffusion models to enhance their stability and efficacy. Specifically, SAG adversarially blurs only the regions that diffusion models attend to at each iteration and guides them accordingly. Our experimental results show that our SAG improves the performance of various diffusion models, including ADM, IDDPM, Stable Diffusion, and DiT. Moreover, combining SAG with conventional guidance methods leads to further improvement.* Resources: * [Project Page](https://ku-cvlab.github.io/Self-Attention-Guidance). * [Paper](https://arxiv.org/abs/2210.00939). * [Original Code](https://github.com/KU-CVLAB/Self-Attention-Guidance). -* [Demo](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb). +* [Hugging Face Demo](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance). +* [Colab Demo](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb). ## Available Pipelines: | Pipeline | Tasks | Demo |---|---|:---:| -| [StableDiffusionSAGPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py) | *Text-to-Image Generation* | [Colab](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb) | +| [StableDiffusionSAGPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py) | *Text-to-Image Generation* | [πŸ€— Space](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance) | ## Usage example diff --git a/docs/source/en/api/pipelines/stable_diffusion_safe.mdx b/docs/source/en/api/pipelines/stable_diffusion_safe.mdx index 900f22badf6f..035c7155ef93 100644 --- a/docs/source/en/api/pipelines/stable_diffusion_safe.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion_safe.mdx @@ -28,15 +28,15 @@ The abstract of the paper is the following: ## Tips -- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion/text2img). +- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./stable_diffusion/text2img). ### Run Safe Stable Diffusion -Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation). +Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](../../using-diffusers/conditional_image_generation). ### Interacting with the Safety Concept -To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`] +To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]: ```python >>> from diffusers import StableDiffusionPipelineSafe @@ -60,7 +60,7 @@ You may use the 4 configurations defined in the [Safe Latent Diffusion paper](ht The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`. -### How to load and use different schedulers. +### How to load and use different schedulers The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx index 40bc3e27af77..ee359d0ba486 100644 --- a/docs/source/en/api/pipelines/stable_unclip.mdx +++ b/docs/source/en/api/pipelines/stable_unclip.mdx @@ -16,6 +16,10 @@ Stable unCLIP checkpoints are finetuned from [stable diffusion 2.1](./stable_dif Stable unCLIP also still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used for text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation. +To know more about the unCLIP process, check out the following paper: + +[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. + ## Tips Stable unCLIP takes a `noise_level` as input during inference. `noise_level` determines how much noise is added @@ -24,50 +28,124 @@ we do not add any additional noise to the image embeddings i.e. `noise_level = 0 ### Available checkpoints: -TODO +* Image variation + * [stabilityai/stable-diffusion-2-1-unclip](https://hf.co/stabilityai/stable-diffusion-2-1-unclip) + * [stabilityai/stable-diffusion-2-1-unclip-small](https://hf.co/stabilityai/stable-diffusion-2-1-unclip-small) +* Text-to-image + * [stabilityai/stable-diffusion-2-1-unclip-small](https://hf.co/stabilityai/stable-diffusion-2-1-unclip-small) ### Text-to-Image Generation +Stable unCLIP can be leveraged for text-to-image generation by pipelining it with the prior model of KakaoBrain's open source DALL-E 2 replication [Karlo](https://huggingface.co/kakaobrain/karlo-v1-alpha) ```python import torch -from diffusers import StableUnCLIPPipeline +from diffusers import UnCLIPScheduler, DDPMScheduler, StableUnCLIPPipeline +from diffusers.models import PriorTransformer +from transformers import CLIPTokenizer, CLIPTextModelWithProjection + +prior_model_id = "kakaobrain/karlo-v1-alpha" +data_type = torch.float16 +prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type) + +prior_text_model_id = "openai/clip-vit-large-patch14" +prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id) +prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type) +prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler") +prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + +stable_unclip_model_id = "stabilityai/stable-diffusion-2-1-unclip-small" pipe = StableUnCLIPPipeline.from_pretrained( - "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 -) # TODO update model path + stable_unclip_model_id, + torch_dtype=data_type, + variant="fp16", + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, +) + pipe = pipe.to("cuda") +wave_prompt = "dramatic wave, the Oceans roar, Strong wave spiral across the oceans as the waves unfurl into roaring crests; perfect wave form; perfect wave shape; dramatic wave shape; wave shape unbelievable; wave; wave shape spectacular" -prompt = "a photo of an astronaut riding a horse on mars" -images = pipe(prompt).images -images[0].save("astronaut_horse.png") +images = pipe(prompt=wave_prompt).images +images[0].save("waves.png") ``` + +For text-to-image we use `stabilityai/stable-diffusion-2-1-unclip-small` as it was trained on CLIP ViT-L/14 embedding, the same as the Karlo model prior. [stabilityai/stable-diffusion-2-1-unclip](https://hf.co/stabilityai/stable-diffusion-2-1-unclip) was trained on OpenCLIP ViT-H, so we don't recommend its use. + + ### Text guided Image-to-Image Variation ```python -import requests -import torch -from PIL import Image -from io import BytesIO - from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( - "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 -) # TODO update model path + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" +) pipe = pipe.to("cuda") -url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0].save("variation_image.png") +``` -response = requests.get(url) -init_image = Image.open(BytesIO(response.content)).convert("RGB") -init_image = init_image.resize((768, 512)) +Optionally, you can also pass a prompt to `pipe` such as: +```python prompt = "A fantasy landscape, trending on artstation" -images = pipe(prompt, init_image).images -images[0].save("fantasy_landscape.png") +images = pipe(init_image, prompt=prompt).images +images[0].save("variation_image_two.png") +``` + +### Memory optimization + +If you are short on GPU memory, you can enable smart CPU offloading so that models that are not needed +immediately for a computation can be offloaded to CPU: + +```python +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" +) +# Offload to CPU. +pipe.enable_model_cpu_offload() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] +``` + +Further memory optimizations are possible by enabling VAE slicing on the pipeline: + +```python +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" +) +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] ``` ### StableUnCLIPPipeline diff --git a/docs/source/en/api/pipelines/text_to_video.mdx b/docs/source/en/api/pipelines/text_to_video.mdx index f1fe794e1537..82b2f19ce1b2 100644 --- a/docs/source/en/api/pipelines/text_to_video.mdx +++ b/docs/source/en/api/pipelines/text_to_video.mdx @@ -10,25 +10,33 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> + + +This pipeline is for research purposes only. + + + # Text-to-video synthesis -Text-to-video synthesis from [ModelScope](https://modelscope.cn/) can be considered the same as Stable Diffusion structure-wise but it is extended to videos instead of static images. More specifically, this system allows us to generate videos from a natural language text prompt. +## Overview + +[VideoFusion: Decomposed Diffusion Models for High-Quality Video Generation](https://arxiv.org/abs/2303.08320) by Zhengxiong Luo, Dayou Chen, Yingya Zhang, Yan Huang, Liang Wang, Yujun Shen, Deli Zhao, Jingren Zhou, Tieniu Tan. -From the [model summary](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis): +The abstract of the paper is the following: -*This model is based on a multi-stage text-to-video generation diffusion model, which inputs a description text and returns a video that matches the text description. Only English input is supported.* +*A diffusion probabilistic model (DPM), which constructs a forward diffusion process by gradually adding noise to data points and learns the reverse denoising process to generate new samples, has been shown to handle complex data distribution. Despite its recent success in image synthesis, applying DPMs to video generation is still challenging due to high-dimensional data spaces. Previous methods usually adopt a standard diffusion process, where frames in the same video clip are destroyed with independent noises, ignoring the content redundancy and temporal correlation. This work presents a decomposed diffusion process via resolving the per-frame noise into a base noise that is shared among all frames and a residual noise that varies along the time axis. The denoising pipeline employs two jointly-learned networks to match the noise decomposition accordingly. Experiments on various datasets confirm that our approach, termed as VideoFusion, surpasses both GAN-based and diffusion-based alternatives in high-quality video generation. We further show that our decomposed formulation can benefit from pre-trained image diffusion models and well-support text-conditioned video creation.* Resources: * [Website](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) * [GitHub repository](https://github.com/modelscope/modelscope/) -* [Spaces] (TODO) +* [πŸ€— Spaces](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis) ## Available Pipelines: | Pipeline | Tasks | Demo |---|---|:---:| -| [DiffusionPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [Spaces] (TODO) +| [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [πŸ€— Spaces](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis) ## Usage example @@ -116,7 +124,7 @@ Here are some sample outputs: * [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/) * [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy) -## DiffusionPipeline -[[autodoc]] DiffusionPipeline +## TextToVideoSDPipeline +[[autodoc]] TextToVideoSDPipeline - all - __call__ diff --git a/docs/source/en/api/pipelines/text_to_video_zero.mdx b/docs/source/en/api/pipelines/text_to_video_zero.mdx new file mode 100644 index 000000000000..3ee10f01c377 --- /dev/null +++ b/docs/source/en/api/pipelines/text_to_video_zero.mdx @@ -0,0 +1,240 @@ + + +# Zero-Shot Text-to-Video Generation + +## Overview + + +[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) by +Levon Khachatryan, +Andranik Movsisyan, +Vahram Tadevosyan, +Roberto Henschel, +[Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, [Humphrey Shi](https://www.humphreyshi.com). + +Our method Text2Video-Zero enables zero-shot video generation using either +1. A textual prompt, or +2. A prompt combined with guidance from poses or edges, or +3. Video Instruct-Pix2Pix, i.e., instruction-guided video editing. + +Results are temporally consistent and follow closely the guidance and textual prompts. + +![teaser-img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2v_zero_teaser.png) + +The abstract of the paper is the following: + +*Recent text-to-video generation approaches rely on computationally heavy training and require large-scale video datasets. In this paper, we introduce a new task of zero-shot text-to-video generation and propose a low-cost approach (without any training or optimization) by leveraging the power of existing text-to-image synthesis methods (e.g., Stable Diffusion), making them suitable for the video domain. +Our key modifications include (i) enriching the latent codes of the generated frames with motion dynamics to keep the global scene and the background time consistent; and (ii) reprogramming frame-level self-attention using a new cross-frame attention of each frame on the first frame, to preserve the context, appearance, and identity of the foreground object. +Experiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing. +As experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.* + + + +Resources: + +* [Project Page](https://text2video-zero.github.io/) +* [Paper](https://arxiv.org/abs/2303.13439) +* [Original Code](https://github.com/Picsart-AI-Research/Text2Video-Zero) + + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [TextToVideoZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py) | *Zero-shot Text-to-Video Generation* | [πŸ€— Space](https://huggingface.co/spaces/PAIR/Text2Video-Zero) + + +## Usage example + +### Text-To-Video + +To generate a video from prompt, run the following python command +```python +import torch +import imageio +from diffusers import TextToVideoZeroPipeline + +model_id = "runwayml/stable-diffusion-v1-5" +pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + +prompt = "A panda is playing guitar on times square" +result = pipe(prompt=prompt).images +result = [(r * 255).astype("uint8") for r in result] +imageio.mimsave("video.mp4", result, fps=4) +``` +You can change these parameters in the pipeline call: +* Motion field strength (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1): + * `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12` +* `T` and `T'` (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1) + * `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=45`, `t1=48` +* Video length: + * `video_length`, the number of frames video_length to be generated. Default: `video_length=8` + + +### Text-To-Video with Pose Control +To generate a video from prompt with additional pose control + +1. Download a demo video + + ```python + from huggingface_hub import hf_hub_download + + filename = "__assets__/poses_skeleton_gifs/dance1_corr.mp4" + repo_id = "PAIR/Text2Video-Zero" + video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename) + ``` + + +2. Read video containing extracted pose images + ```python + from PIL import Image + import imageio + + reader = imageio.get_reader(video_path, "ffmpeg") + frame_count = 8 + pose_images = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)] + ``` + To extract pose from actual video, read [ControlNet documentation](./stable_diffusion/controlnet). + +3. Run `StableDiffusionControlNetPipeline` with our custom attention processor + + ```python + import torch + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor + + model_id = "runwayml/stable-diffusion-v1-5" + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16) + pipe = StableDiffusionControlNetPipeline.from_pretrained( + model_id, controlnet=controlnet, torch_dtype=torch.float16 + ).to("cuda") + + # Set the attention processor + pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) + pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) + + # fix latents for all frames + latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1) + + prompt = "Darth Vader dancing in a desert" + result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images + imageio.mimsave("video.mp4", result, fps=4) + ``` + + +### Text-To-Video with Edge Control + +To generate a video from prompt with additional pose control, +follow the steps described above for pose-guided generation using [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny). + + +### Video Instruct-Pix2Pix + +To perform text-guided video editing (with [InstructPix2Pix](./stable_diffusion/pix2pix)): + +1. Download a demo video + + ```python + from huggingface_hub import hf_hub_download + + filename = "__assets__/pix2pix video/camel.mp4" + repo_id = "PAIR/Text2Video-Zero" + video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename) + ``` + +2. Read video from path + ```python + from PIL import Image + import imageio + + reader = imageio.get_reader(video_path, "ffmpeg") + frame_count = 8 + video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)] + ``` + +3. Run `StableDiffusionInstructPix2PixPipeline` with our custom attention processor + ```python + import torch + from diffusers import StableDiffusionInstructPix2PixPipeline + from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor + + model_id = "timbrooks/instruct-pix2pix" + pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=3)) + + prompt = "make it Van Gogh Starry Night style" + result = pipe(prompt=[prompt] * len(video), image=video).images + imageio.mimsave("edited_video.mp4", result, fps=4) + ``` + + +### DreamBooth specialization + +Methods **Text-To-Video**, **Text-To-Video with Pose Control** and **Text-To-Video with Edge Control** +can run with custom [DreamBooth](../training/dreambooth) models, as shown below for +[Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) and +[Avatar style DreamBooth](https://huggingface.co/PAIR/text2video-zero-controlnet-canny-avatar) model + +1. Download a demo video + + ```python + from huggingface_hub import hf_hub_download + + filename = "__assets__/canny_videos_mp4/girl_turning.mp4" + repo_id = "PAIR/Text2Video-Zero" + video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename) + ``` + +2. Read video from path + ```python + from PIL import Image + import imageio + + reader = imageio.get_reader(video_path, "ffmpeg") + frame_count = 8 + video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)] + ``` + +3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model + ```python + import torch + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor + + # set model id to custom model + model_id = "PAIR/text2video-zero-controlnet-canny-avatar" + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + pipe = StableDiffusionControlNetPipeline.from_pretrained( + model_id, controlnet=controlnet, torch_dtype=torch.float16 + ).to("cuda") + + # Set the attention processor + pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) + pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) + + # fix latents for all frames + latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1) + + prompt = "oil painting of a beautiful girl avatar style" + result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images + imageio.mimsave("video.mp4", result, fps=4) + ``` + +You can filter out some available DreamBooth-trained models with [this link](https://huggingface.co/models?search=dreambooth). + + + +## TextToVideoZeroPipeline +[[autodoc]] TextToVideoZeroPipeline + - all + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/versatile_diffusion.mdx b/docs/source/en/api/pipelines/versatile_diffusion.mdx index bfafa8e8f1fc..f87fdc93e36e 100644 --- a/docs/source/en/api/pipelines/versatile_diffusion.mdx +++ b/docs/source/en/api/pipelines/versatile_diffusion.mdx @@ -20,7 +20,7 @@ The abstract of the paper is the following: ## Tips -- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion/overview), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image. +- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./stable_diffusion/overview), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image. ### *Run VersatileDiffusion* diff --git a/docs/source/en/api/schedulers/ddim.mdx b/docs/source/en/api/schedulers/ddim.mdx index dc9bdd59a03e..51b0cc3e9a09 100644 --- a/docs/source/en/api/schedulers/ddim.mdx +++ b/docs/source/en/api/schedulers/ddim.mdx @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Denoising diffusion implicit models (DDIM) +# Denoising Diffusion Implicit Models (DDIM) ## Overview @@ -24,4 +24,4 @@ The original codebase of this paper can be found here: [ermongroup/ddim](https:/ For questions, feel free to contact the author on [tsong.me](https://tsong.me/). ## DDIMScheduler -[[autodoc]] DDIMScheduler \ No newline at end of file +[[autodoc]] DDIMScheduler diff --git a/docs/source/en/api/schedulers/ddpm.mdx b/docs/source/en/api/schedulers/ddpm.mdx index 76ea248a01a8..6c4058b941fa 100644 --- a/docs/source/en/api/schedulers/ddpm.mdx +++ b/docs/source/en/api/schedulers/ddpm.mdx @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Denoising diffusion probabilistic models (DDPM) +# Denoising Diffusion Probabilistic Models (DDPM) ## Overview @@ -24,4 +24,4 @@ We present high quality image synthesis results using diffusion probabilistic mo The original paper can be found [here](https://arxiv.org/abs/2010.02502). ## DDPMScheduler -[[autodoc]] DDPMScheduler \ No newline at end of file +[[autodoc]] DDPMScheduler diff --git a/docs/source/en/api/schedulers/euler_ancestral.mdx b/docs/source/en/api/schedulers/euler_ancestral.mdx index 0fc74f471633..60fd524b1955 100644 --- a/docs/source/en/api/schedulers/euler_ancestral.mdx +++ b/docs/source/en/api/schedulers/euler_ancestral.mdx @@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License. ## Overview -Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson. +Ancestral sampling with Euler method steps. Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) implementation by Katherine Crowson. Fast scheduler which often times generates good outputs with 20-30 steps. ## EulerAncestralDiscreteScheduler -[[autodoc]] EulerAncestralDiscreteScheduler \ No newline at end of file +[[autodoc]] EulerAncestralDiscreteScheduler diff --git a/docs/source/en/api/schedulers/score_sde_ve.mdx b/docs/source/en/api/schedulers/score_sde_ve.mdx index 0906227229ea..66a00c69e3b4 100644 --- a/docs/source/en/api/schedulers/score_sde_ve.mdx +++ b/docs/source/en/api/schedulers/score_sde_ve.mdx @@ -10,11 +10,11 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# variance exploding stochastic differential equation (VE-SDE) scheduler +# Variance Exploding Stochastic Differential Equation (VE-SDE) scheduler ## Overview Original paper can be found [here](https://arxiv.org/abs/2011.13456). ## ScoreSdeVeScheduler -[[autodoc]] ScoreSdeVeScheduler \ No newline at end of file +[[autodoc]] ScoreSdeVeScheduler diff --git a/docs/source/en/api/schedulers/score_sde_vp.mdx b/docs/source/en/api/schedulers/score_sde_vp.mdx index 19a628256e6a..ac1d2f109c81 100644 --- a/docs/source/en/api/schedulers/score_sde_vp.mdx +++ b/docs/source/en/api/schedulers/score_sde_vp.mdx @@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Variance preserving stochastic differential equation (VP-SDE) scheduler +# Variance Preserving Stochastic Differential Equation (VP-SDE) scheduler ## Overview @@ -23,4 +23,4 @@ Score SDE-VP is under construction. ## ScoreSdeVpScheduler -[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler \ No newline at end of file +[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler diff --git a/docs/source/en/api/schedulers/unipc.mdx b/docs/source/en/api/schedulers/unipc.mdx index 1ed49b7727fc..134dc1ef3170 100644 --- a/docs/source/en/api/schedulers/unipc.mdx +++ b/docs/source/en/api/schedulers/unipc.mdx @@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License. UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. -For more details about the method, please refer to the [[paper]](https://arxiv.org/abs/2302.04867) and the [[code]](https://github.com/wl-zhao/UniPC). +For more details about the method, please refer to the [paper](https://arxiv.org/abs/2302.04867) and the [code](https://github.com/wl-zhao/UniPC). Fast Sampling of Diffusion Models with Exponential Integrator. diff --git a/docs/source/en/conceptual/contribution.mdx b/docs/source/en/conceptual/contribution.mdx index e0e873892ca2..7b78d318b679 100644 --- a/docs/source/en/conceptual/contribution.mdx +++ b/docs/source/en/conceptual/contribution.mdx @@ -170,7 +170,7 @@ please have a look at the next sections. For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull requst](#how-to-open-a-pr) section. -### 4. Fixing a "Good first issue" +### 4. Fixing a `Good first issue` *Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already explains how a potential solution should look so that it is easier to fix. @@ -275,7 +275,7 @@ Once an example script works, please make sure to add a comprehensive `README.md If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples. -### 8. Fixing a "Good second issue" +### 8. Fixing a `Good second issue` *Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). @@ -439,7 +439,7 @@ Push the changes to your account using: $ git push -u origin a-descriptive-name-for-my-changes ``` -6. Once you are satisfied (**and the checklist below is happy too**), go to the +6. Once you are satisfied, go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. diff --git a/docs/source/en/conceptual/evaluation.mdx b/docs/source/en/conceptual/evaluation.mdx index 98821010e203..2721adea0c16 100644 --- a/docs/source/en/conceptual/evaluation.mdx +++ b/docs/source/en/conceptual/evaluation.mdx @@ -310,7 +310,7 @@ for idx in range(len(dataset)): edited_images.append(edited_image) ``` -To measure the directional similarity, we first load CLIP's image and text encoders. +To measure the directional similarity, we first load CLIP's image and text encoders: ```python from transformers import ( @@ -329,7 +329,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device Notice that we are using a particular CLIP checkpoint, i.e.,Β `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to theΒ [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix#diffusers.StableDiffusionInstructPix2PixPipeline.text_encoder). -Next, we prepare a PyTorchΒ `nn.module`Β to compute directional similarity: +Next, we prepare a PyTorchΒ `nn.Module`Β to compute directional similarity: ```python import torch.nn as nn @@ -410,7 +410,7 @@ It should be noted that theΒ `StableDiffusionInstructPix2PixPipeline`Β exposes t We can extend the idea of this metric to measure how similar the original image and edited version are. To do that, we can just doΒ `F.cosine_similarity(img_feat_two, img_feat_one)`. For these kinds of edits, we would still want the primary semantics of the images to be preserved as much as possible, i.e., a high similarity score. -We can use these metrics for similar pipelines such as the[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)`. +We can use these metrics for similar pipelines such as the [`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline). @@ -550,7 +550,7 @@ FID results tend to be fragile as they depend on a lot of factors: * The image format (not the same if we start from PNGs vs JPGs). Keeping that in mind, FID is often most useful when comparing similar runs, but it is -hard to to reproduce paper results unless the authors carefully disclose the FID +hard to reproduce paper results unless the authors carefully disclose the FID measurement code. These points apply to other related metrics too, such as KID and IS. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 2ccabb1b32ee..10a237f29278 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -73,9 +73,10 @@ The library has three main components: | [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://arxiv.org/abs/2211.09800) | Text-Guided Image Editing| | [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [Zero-shot Image-to-Image Translation](https://pix2pixzero.github.io/) | Text-Guided Image Editing | | [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://arxiv.org/abs/2301.13826) | Text-to-Image Generation | -| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) | Text-to-Image Generation | +| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) | Text-to-Image Generation Unconditional Image Generation | | [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation | | [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [Stable Diffusion Latent Upscaler](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_model_editing](./api/pipelines/stable_diffusion/model_editing) | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://time-diffusion.github.io/) | Text-to-Image Model Editing | | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Depth-Conditional Stable Diffusion](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation | diff --git a/docs/source/en/optimization/coreml.mdx b/docs/source/en/optimization/coreml.mdx new file mode 100644 index 000000000000..ab96eea0fb04 --- /dev/null +++ b/docs/source/en/optimization/coreml.mdx @@ -0,0 +1,167 @@ + + +# How to run Stable Diffusion with Core ML + +[Core ML](https://developer.apple.com/documentation/coreml) is the model format and machine learning library supported by Apple frameworks. If you are interested in running Stable Diffusion models inside your macOS or iOS/iPadOS apps, this guide will show you how to convert existing PyTorch checkpoints into the Core ML format and use them for inference with Python or Swift. + +Core ML models can leverage all the compute engines available in Apple devices: the CPU, the GPU, and the Apple Neural Engine (or ANE, a tensor-optimized accelerator available in Apple Silicon Macs and modern iPhones/iPads). Depending on the model and the device it's running on, Core ML can mix and match compute engines too, so some portions of the model may run on the CPU while others run on GPU, for example. + + + +You can also run the `diffusers` Python codebase on Apple Silicon Macs using the `mps` accelerator built into PyTorch. This approach is explained in depth in [the mps guide](mps), but it is not compatible with native apps. + + + +## Stable Diffusion Core ML Checkpoints + +Stable Diffusion weights (or checkpoints) are stored in the PyTorch format, so you need to convert them to the Core ML format before we can use them inside native apps. + +Thankfully, Apple engineers developed [a conversion tool](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) based on `diffusers` to convert the PyTorch checkpoints to Core ML. + +Before you convert a model, though, take a moment to explore the Hugging Face Hub – chances are the model you're interested in is already available in Core ML format: + +- the [Apple](https://huggingface.co/apple) organization includes Stable Diffusion versions 1.4, 1.5, 2.0 base, and 2.1 base +- [coreml](https://huggingface.co/coreml) organization includes custom DreamBoothed and finetuned models +- use this [filter](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) to return all available Core ML checkpoints + +If you can't find the model you're interested in, we recommend you follow the instructions for [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) by Apple. + +## Selecting the Core ML Variant to Use + +Stable Diffusion models can be converted to different Core ML variants intended for different purposes: + +- The type of attention blocks used. The attention operation is used to "pay attention" to the relationship between different areas in the image representations and to understand how the image and text representations are related. Attention is compute- and memory-intensive, so different implementations exist that consider the hardware characteristics of different devices. For Core ML Stable Diffusion models, there are two attention variants: + * `split_einsum` ([introduced by Apple](https://machinelearning.apple.com/research/neural-engine-transformers)) is optimized for ANE devices, which is available in modern iPhones, iPads and M-series computers. + * The "original" attention (the base implementation used in `diffusers`) is only compatible with CPU/GPU and not ANE. It can be *faster* to run your model on CPU + GPU using `original` attention than ANE. See [this performance benchmark](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks) as well as some [additional measures provided by the community](https://github.com/huggingface/swift-coreml-diffusers/issues/31) for additional details. + +- The supported inference framework. + * `packages` are suitable for Python inference. This can be used to test converted Core ML models before attempting to integrate them inside native apps, or if you want to explore Core ML performance but don't need to support native apps. For example, an application with a web UI could perfectly use a Python Core ML backend. + * `compiled` models are required for Swift code. The `compiled` models in the Hub split the large UNet model weights into several files for compatibility with iOS and iPadOS devices. This corresponds to the [`--chunk-unet` conversion option](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml). If you want to support native apps, then you need to select the `compiled` variant. + +The official Core ML Stable Diffusion [models](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main) include these variants, but the community ones may vary: + +``` +coreml-stable-diffusion-v1-4 +β”œβ”€β”€ README.md +β”œβ”€β”€ original +β”‚ β”œβ”€β”€ compiled +β”‚ └── packages +└── split_einsum + β”œβ”€β”€ compiled + └── packages +``` + +You can download and use the variant you need as shown below. + +## Core ML Inference in Python + +Install the following libraries to run Core ML inference in Python: + +```bash +pip install huggingface_hub +pip install git+https://github.com/apple/ml-stable-diffusion +``` + +### Download the Model Checkpoints + +To run inference in Python, use one of the versions stored in the `packages` folders because the `compiled` ones are only compatible with Swift. You may choose whether you want to use `original` or `split_einsum` attention. + +This is how you'd download the `original` attention variant from the Hub to a directory called `models`: + +```Python +from huggingface_hub import snapshot_download +from pathlib import Path + +repo_id = "apple/coreml-stable-diffusion-v1-4" +variant = "original/packages" + +model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_")) +snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False) +print(f"Model downloaded at {model_path}") +``` + + +### Inference[[python-inference]] + +Once you have downloaded a snapshot of the model, you can test it using Apple's Python script. + +```shell +python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i models/coreml-stable-diffusion-v1-4_original_packages -o --compute-unit CPU_AND_GPU --seed 93 +``` + +`` should point to the checkpoint you downloaded in the step above, and `--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path, and a seed for reproducibility. + +The inference script assumes you're using the original version of the Stable Diffusion model, `CompVis/stable-diffusion-v1-4`. If you use another model, you *have* to specify its Hub id in the inference command line, using the `--model-version` option. This works for models already supported and custom models you trained or fine-tuned yourself. + +For example, if you want to use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5): + +```shell +python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version runwayml/stable-diffusion-v1-5 +``` + + +## Core ML inference in Swift + +Running inference in Swift is slightly faster than in Python because the models are already compiled in the `mlmodelc` format. This is noticeable on app startup when the model is loaded but shouldn’t be noticeable if you run several generations afterward. + +### Download + +To run inference in Swift on your Mac, you need one of the `compiled` checkpoint versions. We recommend you download them locally using Python code similar to the previous example, but with one of the `compiled` variants: + +```Python +from huggingface_hub import snapshot_download +from pathlib import Path + +repo_id = "apple/coreml-stable-diffusion-v1-4" +variant = "original/compiled" + +model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_")) +snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False) +print(f"Model downloaded at {model_path}") +``` + +### Inference[[swift-inference]] + +To run inference, please clone Apple's repo: + +```bash +git clone https://github.com/apple/ml-stable-diffusion +cd ml-stable-diffusion +``` + +And then use Apple's command line tool, [Swift Package Manager](https://www.swift.org/package-manager/#): + +```bash +swift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all "a photo of an astronaut riding a horse on mars" +``` + +You have to specify in `--resource-path` one of the checkpoints downloaded in the previous step, so please make sure it contains compiled Core ML bundles with the extension `.mlmodelc`. The `--compute-units` has to be one of these values: `all`, `cpuOnly`, `cpuAndGPU`, `cpuAndNeuralEngine`. + +For more details, please refer to the [instructions in Apple's repo](https://github.com/apple/ml-stable-diffusion). + + +## Supported Diffusers Features + +The Core ML models and inference code don't support many of the features, options, and flexibility of 🧨 Diffusers. These are some of the limitations to keep in mind: + +- Core ML models are only suitable for inference. They can't be used for training or fine-tuning. +- Only two schedulers have been ported to Swift, the default one used by Stable Diffusion and `DPMSolverMultistepScheduler`, which we ported to Swift from our `diffusers` implementation. We recommend you use `DPMSolverMultistepScheduler`, since it produces the same quality in about half the steps. +- Negative prompts, classifier-free guidance scale, and image-to-image tasks are available in the inference code. Advanced features such as depth guidance, ControlNet, and latent upscalers are not available yet. + +Apple's [conversion and inference repo](https://github.com/apple/ml-stable-diffusion) and our own [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) repos are intended as technology demonstrators to enable other developers to build upon. + +If you feel strongly about any missing features, please feel free to open a feature request or, better yet, a contribution PR :) + +## Native Diffusers Swift app + +One easy way to run Stable Diffusion on your own Apple hardware is to use [our open-source Swift repo](https://github.com/huggingface/swift-coreml-diffusers), based on `diffusers` and Apple's conversion and inference repo. You can study the code, compile it with [Xcode](https://developer.apple.com/xcode/) and adapt it for your own needs. For your convenience, there's also a [standalone Mac app in the App Store](https://apps.apple.com/app/diffusers/id1666309574), so you can play with it without having to deal with the code or IDE. If you are a developer and have determined that Core ML is the best solution to build your Stable Diffusion app, then you can use the rest of this guide to get started with your project. We can't wait to see what you'll build :) diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index c18cefbde6a9..d05c5aabea2b 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -19,7 +19,6 @@ We'll discuss how the following settings impact performance and memory. | | Latency | Speedup | | ---------------- | ------- | ------- | | original | 9.50s | x1 | -| cuDNN auto-tuner | 9.37s | x1.01 | | fp16 | 3.61s | x2.63 | | channels last | 3.30s | x2.88 | | traced UNet | 3.21s | x2.96 | @@ -31,18 +30,6 @@ We'll discuss how the following settings impact performance and memory. steps. -## Enable cuDNN auto-tuner - -[NVIDIA cuDNN](https://developer.nvidia.com/cudnn)Β supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size. - -Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting: - -```python -import torch - -torch.backends.cudnn.benchmark = True -``` - ### Use tf32 instead of fp32 (on Ampere and later CUDA devices) On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference: @@ -58,7 +45,10 @@ torch.backends.cuda.matmul.allow_tf32 = True To save more GPU memory and get more speed, you can load and run the model weights directly in half precision. This involves loading the float16 version of the weights, which was saved to a branch named `fp16`, and telling PyTorch to use the `float16` type when loading them: ```Python -pipe = StableDiffusionPipeline.from_pretrained( +import torch +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, @@ -85,13 +75,13 @@ For even additional memory savings, you can use a sliced version of attention th each head which can save a significant amount of memory. -To perform the attention computation sequentially over each head, you only need to invoke [`~StableDiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, like here: +To perform the attention computation sequentially over each head, you only need to invoke [`~DiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, like here: ```Python import torch -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained( +pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, @@ -415,10 +405,10 @@ To leverage it just make sure you have: - Cuda available - [Installed the xformers library](xformers). ```python -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline import torch -pipe = StableDiffusionPipeline.from_pretrained( +pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, ).to("cuda") diff --git a/docs/source/en/optimization/mps.mdx b/docs/source/en/optimization/mps.mdx index 3750724bce57..3be8c621ee3e 100644 --- a/docs/source/en/optimization/mps.mdx +++ b/docs/source/en/optimization/mps.mdx @@ -35,9 +35,9 @@ The snippet below demonstrates how to use the `mps` backend using the familiar ` We strongly recommend you use PyTorch 2 or better, as it solves a number of problems like the one described in the previous tip. ```python -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("mps") # Recommended if your computer has < 64 GB of RAM diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx index a6a40469e97b..206ac4e447cc 100644 --- a/docs/source/en/optimization/torch2.0.mdx +++ b/docs/source/en/optimization/torch2.0.mdx @@ -35,9 +35,9 @@ pip install --upgrade torch torchvision diffusers ```Python import torch - from diffusers import StableDiffusionPipeline + from diffusers import DiffusionPipeline - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" @@ -48,10 +48,10 @@ pip install --upgrade torch torchvision diffusers ```Python import torch - from diffusers import StableDiffusionPipeline + from diffusers import DiffusionPipeline from diffusers.models.attention_processor import AttnProcessor2_0 - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") + pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") pipe.unet.set_attn_processor(AttnProcessor2_0()) prompt = "a photo of an astronaut riding a horse on mars" @@ -68,11 +68,9 @@ pip install --upgrade torch torchvision diffusers ```python import torch - from diffusers import StableDiffusionPipeline + from diffusers import DiffusionPipeline - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to( - "cuda" - ) + pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") pipe.unet = torch.compile(pipe.unet) batch_size = 10 diff --git a/docs/source/en/quicktour.mdx b/docs/source/en/quicktour.mdx index 3aecb422af2a..d494b79dccd5 100644 --- a/docs/source/en/quicktour.mdx +++ b/docs/source/en/quicktour.mdx @@ -141,7 +141,7 @@ Different schedulers come with different denoising speeds and quality trade-offs ```py >>> from diffusers import EulerDiscreteScheduler ->>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) ``` diff --git a/docs/source/en/stable_diffusion.mdx b/docs/source/en/stable_diffusion.mdx index 8190813e488a..eebe0ec660f2 100644 --- a/docs/source/en/stable_diffusion.mdx +++ b/docs/source/en/stable_diffusion.mdx @@ -1,333 +1,271 @@ - - -# The Stable Diffusion Guide 🎨 - - Open In Colab - - -## Intro - -Stable Diffusion is a [Latent Diffusion model](https://github.com/CompVis/latent-diffusion) developed by researchers from the Machine Vision and Learning group at LMU Munich, *a.k.a* CompVis. -Model checkpoints were publicly released at the end of August 2022 by a collaboration of Stability AI, CompVis, and Runway with support from EleutherAI and LAION. For more information, you can check out [the official blog post](https://stability.ai/blog/stable-diffusion-public-release). - -Since its public release the community has done an incredible job at working together to make the stable diffusion checkpoints **faster**, **more memory efficient**, and **more performant**. - -🧨 Diffusers offers a simple API to run stable diffusion with all memory, computing, and quality improvements. - -This notebook walks you through the improvements one-by-one so you can best leverage [`StableDiffusionPipeline`] for **inference**. - -## Prompt Engineering 🎨 - -When running *Stable Diffusion* in inference, we usually want to generate a certain type, or style of image and then improve upon it. Improving upon a previously generated image means running inference over and over again with a different prompt and potentially a different seed until we are happy with our generation. - -So to begin with, it is most important to speed up stable diffusion as much as possible to generate as many pictures as possible in a given amount of time. - -This can be done by both improving the **computational efficiency** (speed) and the **memory efficiency** (GPU RAM). - -Let's start by looking into computational efficiency first. - -Throughout the notebook, we will focus on [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5): - -``` python -model_id = "runwayml/stable-diffusion-v1-5" -``` - -Let's load the pipeline. - -## Speed Optimization - -``` python -from diffusers import StableDiffusionPipeline - -pipe = StableDiffusionPipeline.from_pretrained(model_id) -``` - -We aim at generating a beautiful photograph of an *old warrior chief* and will later try to find the best prompt to generate such a photograph. For now, let's keep the prompt simple: - -``` python -prompt = "portrait photo of a old warrior chief" -``` - -To begin with, we should make sure we run inference on GPU, so let's move the pipeline to GPU, just like you would with any PyTorch module. - -``` python -pipe = pipe.to("cuda") -``` - -To generate an image, you should use the [~`StableDiffusionPipeline.__call__`] method. - -To make sure we can reproduce more or less the same image in every call, let's make use of the generator. See the documentation on reproducibility [here](./conceptual/reproducibility) for more information. - -``` python -generator = torch.Generator("cuda").manual_seed(0) -``` - -Now, let's take a spin on it. - -``` python -image = pipe(prompt, generator=generator).images[0] -image -``` - -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_1.png) - -Cool, this now took roughly 30 seconds on a T4 GPU (you might see faster inference if your allocated GPU is better than a T4). - -The default run we did above used full float32 precision and ran the default number of inference steps (50). The easiest speed-ups come from switching to float16 (or half) precision and simply running fewer inference steps. Let's load the model now in float16 instead. - -``` python -import torch - -pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) -pipe = pipe.to("cuda") -``` - -And we can again call the pipeline to generate an image. - -``` python -generator = torch.Generator("cuda").manual_seed(0) - -image = pipe(prompt, generator=generator).images[0] -image -``` -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_2.png) - -Cool, this is almost three times as fast for arguably the same image quality. - -We strongly suggest always running your pipelines in float16 as so far we have very rarely seen degradations in quality because of it. - -Next, let's see if we need to use 50 inference steps or whether we could use significantly fewer. The number of inference steps is associated with the denoising scheduler we use. Choosing a more efficient scheduler could help us decrease the number of steps. - -Let's have a look at all the schedulers the stable diffusion pipeline is compatible with. - -``` python -pipe.scheduler.compatibles -``` - -``` - [diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler, - diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler, - diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler, - diffusers.schedulers.scheduling_pndm.PNDMScheduler, - diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler, - diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler, - diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler, - diffusers.schedulers.scheduling_ddpm.DDPMScheduler, - diffusers.schedulers.scheduling_ddim.DDIMScheduler] -``` - -Cool, that's a lot of schedulers. - -🧨 Diffusers is constantly adding a bunch of novel schedulers/samplers that can be used with Stable Diffusion. For more information, we recommend taking a look at the official documentation [here](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview). - -Alright, right now Stable Diffusion is using the `PNDMScheduler` which usually requires around 50 inference steps. However, other schedulers such as `DPMSolverMultistepScheduler` or `DPMSolverSinglestepScheduler` seem to get away with just 20 to 25 inference steps. Let's try them out. - -You can set a new scheduler by making use of the [from_config](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) function. - -``` python -from diffusers import DPMSolverMultistepScheduler - -pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) -``` - -Now, let's try to reduce the number of inference steps to just 20. - -``` python -generator = torch.Generator("cuda").manual_seed(0) - -image = pipe(prompt, generator=generator, num_inference_steps=20).images[0] -image -``` - -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_3.png) - -The image now does look a little different, but it's arguably still of equally high quality. We now cut inference time to just 4 seconds though 😍. - -## Memory Optimization + + +# Effective and efficient diffusion -``` python -def get_inputs(batch_size=1): - generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)] - prompts = batch_size * [prompt] - num_inference_steps = 20 +[[open-in-colab]] - return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps} -``` -This function returns a list of prompts and a list of generators, so we can reuse the generator that produced a result we like. +Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Often times, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again. -We also need a method that allows us to easily display a batch of images. +This is why it's important to get the most *computational* (speed) and *memory* (GPU RAM) efficiency from the pipeline to reduce the time between inference cycles so you can iterate faster. -``` python -from PIL import Image +This tutorial walks you through how to generate faster and better with the [`DiffusionPipeline`]. -def image_grid(imgs, rows=2, cols=2): - w, h = imgs[0].size - grid = Image.new('RGB', size=(cols*w, rows*h)) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i%cols*w, i//cols*h)) - return grid -``` +Begin by loading the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model: -Cool, let's see how much memory we can use starting with `batch_size=4`. +```python +from diffusers import DiffusionPipeline -``` python -images = pipe(**get_inputs(batch_size=4)).images -image_grid(images) -``` +model_id = "runwayml/stable-diffusion-v1-5" +pipeline = DiffusionPipeline.from_pretrained(model_id) +``` -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_4.png) +The example prompt you'll use is a portrait of an old warrior chief, but feel free to use your own prompt: -Going over a batch_size of 4 will error out in this notebook (assuming we are running it on a T4 GPU). Also, we can see we only generate slightly more images per second (3.75s/image) compared to 4s/image previously. +```python +prompt = "portrait photo of a old warrior chief" +``` -However, the community has found some nice tricks to improve the memory constraints further. After stable diffusion was released, the community found improvements within days and shared them freely over GitHub - open-source at its finest! I believe the original idea came from [this](https://github.com/basujindal/stable-diffusion/pull/117) GitHub thread. +## Speed -By far most of the memory is taken up by the cross-attention layers. Instead of running this operation in batch, one can run it sequentially to save a significant amount of memory. + -It can easily be enabled by calling `enable_attention_slicing` as is documented [here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.enable_attention_slicing). +πŸ’‘ If you don't have access to a GPU, you can use one for free from a GPU provider like [Colab](https://colab.research.google.com/)! -``` python -pipe.enable_attention_slicing() -``` + -Great, now that attention slicing is enabled, let's try to double the batch size again, going for `batch_size=8`. +One of the simplest ways to speed up inference is to place the pipeline on a GPU the same way you would with any PyTorch module: -``` python -images = pipe(**get_inputs(batch_size=8)).images -image_grid(images, rows=2, cols=4) -``` +```python +pipeline = pipeline.to("cuda") +``` -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_5.png) +To make sure you can use the same image and improve on it, use a [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed for [reproducibility](./using-diffusers/reproducibility): -Nice, it works. However, the speed gain is again not very big (it might however be much more significant on other GPUs). +```python +generator = torch.Generator("cuda").manual_seed(0) +``` -We're at roughly 3.5 seconds per image πŸ”₯ which is probably the fastest we can be with a simple T4 without sacrificing quality. +Now you can generate an image: -Next, let's look into how to improve the quality! +```python +image = pipeline(prompt, generator=generator).images[0] +image +``` -## Quality Improvements +
+ +
-Now that our image generation pipeline is blazing fast, let's try to get maximum image quality. +This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps. -First of all, image quality is extremely subjective, so it's difficult to make general claims here. +Let's start by loading the model in `float16` and generate an image: -The most obvious step to take to improve quality is to use *better checkpoints*. Since the release of Stable Diffusion, many improved versions have been released, which are summarized here: +```python +import torch -- *Official Release - 22 Aug 2022*: [Stable-Diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4) -- *20 October 2022*: [Stable-Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) -- *24 Nov 2022*: [Stable-Diffusion 2.0](https://huggingface.co/stabilityai/stable-diffusion-2-0) -- *7 Dec 2022*: [Stable-Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) +pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) +pipeline = pipeline.to("cuda") +generator = torch.Generator("cuda").manual_seed(0) +image = pipeline(prompt, generator=generator).images[0] +image +``` -Newer versions don't necessarily mean better image quality with the same parameters. People mentioned that *2.0* is slightly worse than *1.5* for certain prompts, but given the right prompt engineering *2.0* and *2.1* seem to be better. +
+ +
-Overall, we strongly recommend just trying the models out and reading up on advice online (e.g. it has been shown that using negative prompts is very important for 2.0 and 2.1 to get the highest possible quality. See for example [this nice blog post](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/). +This time, it only took ~11 seconds to generate the image, which is almost 3x faster than before! -Additionally, the community has started fine-tuning many of the above versions on certain styles with some of them having an extremely high quality and gaining a lot of traction. + -We recommend having a look at all [diffusers checkpoints sorted by downloads and trying out the different checkpoints](https://huggingface.co/models?library=diffusers). +πŸ’‘ We strongly suggest always running your pipelines in `float16`, and so far, we've rarely seen any degradation in output quality. -For the following, we will stick to v1.5 for simplicity. + -Next, we can also try to optimize single components of the pipeline, e.g. switching out the latent decoder. For more details on how the whole Stable Diffusion pipeline works, please have a look at [this blog post](https://huggingface.co/blog/stable_diffusion). +Another option is to reduce the number of inference steps. Choosing a more efficient scheduler could help decrease the number of steps without sacrificing output quality. You can find which schedulers are compatible with the current model in the [`DiffusionPipeline`] by calling the `compatibles` method: -Let's load [stabilityai's newest auto-decoder](https://huggingface.co/stabilityai/stable-diffusion-2-1). +```python +pipeline.scheduler.compatibles +[ + diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler, + diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler, + diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler, + diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler, + diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler, + diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler, + diffusers.schedulers.scheduling_ddpm.DDPMScheduler, + diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler, + diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler, + diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler, + diffusers.schedulers.scheduling_pndm.PNDMScheduler, + diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler, + diffusers.schedulers.scheduling_ddim.DDIMScheduler, +] +``` -``` python -from diffusers import AutoencoderKL +The Stable Diffusion model uses the [`PNDMScheduler`] by default which usually requires ~50 inference steps, but more performant schedulers like [`DPMSolverMultistepScheduler`], require only ~20 or 25 inference steps. Use the [`ConfigMixin.from_config`] method to load a new scheduler: -vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda") -``` +```python +from diffusers import DPMSolverMultistepScheduler -Now we can set it to the vae of the pipeline to use it. +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) +``` -``` python -pipe.vae = vae -``` +Now set the `num_inference_steps` to 20: -Let's run the same prompt as before to compare quality. +```python +generator = torch.Generator("cuda").manual_seed(0) +image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0] +image +``` -``` python -images = pipe(**get_inputs(batch_size=8)).images -image_grid(images, rows=2, cols=4) -``` +
+ +
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_6.png) +Great, you've managed to cut the inference time to just 4 seconds! ⚑️ -Seems like the difference is only very minor, but the new generations are arguably a bit *sharper*. +## Memory -Cool, finally, let's look a bit into prompt engineering. +The other key to improving pipeline performance is consuming less memory, which indirectly implies more speed, since you're often trying to maximize the number of images generated per second. The easiest way to see how many images you can generate at once is to try out different batch sizes until you get an `OutOfMemoryError` (OOM). -Our goal was to generate a photo of an old warrior chief. Let's now try to bring a bit more color into the photos and make the look more impressive. +Create a function that'll generate a batch of images from a list of prompts and `Generators`. Make sure to assign each `Generator` a seed so you can reuse it if it produces a good result. -Originally our prompt was "*portrait photo of an old warrior chief*". +```python +def get_inputs(batch_size=1): + generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)] + prompts = batch_size * [prompt] + num_inference_steps = 20 -To improve the prompt, it often helps to add cues that could have been used online to save high-quality photos, as well as add more details. -Essentially, when doing prompt engineering, one has to think: + return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps} +``` -- How was the photo or similar photos of the one I want probably stored on the internet? -- What additional detail can I give that steers the models into the style that I want? +You'll also need a function that'll display each batch of images: -Cool, let's add more details. +```python +from PIL import image -``` python -prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes" -``` -and let's also add some cues that usually help to generate higher quality images. +def image_grid(imgs, rows=2, cols=2): + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) -``` python -prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta" -prompt -``` + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid +``` -Cool, let's now try this prompt. +Start with `batch_size=4` and see how much memory you've consumed: -``` python -images = pipe(**get_inputs(batch_size=8)).images -image_grid(images, rows=2, cols=4) -``` +```python +images = pipeline(**get_inputs(batch_size=4)).images +image_grid(images) +``` -![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_7.png) +Unless you have a GPU with more RAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function: -Pretty impressive! We got some very high-quality image generations there. The 2nd image is my personal favorite, so I'll re-use this seed and see whether I can tweak the prompts slightly by using "oldest warrior", "old", "", and "young" instead of "old". +```python +pipeline.enable_attention_slicing() +``` -``` python -prompts = [ - "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", - "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", - "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", - "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", -] +Now try increasing the `batch_size` to 8! -generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))] # 1 because we want the 2nd image +```python +images = pipeline(**get_inputs(batch_size=8)).images +image_grid(images, rows=2, cols=4) +``` -images = pipe(prompt=prompts, generator=generator, num_inference_steps=25).images -image_grid(images) -``` +
+ +
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_8.png) +Whereas before you couldn't even generate a batch of 4 images, now you can generate a batch of 8 images at ~3.5 seconds per image! This is probably the fastest you can go on a T4 GPU without sacrificing quality. -The first picture looks nice! The eye movement slightly changed and looks nice. This finished up our 101-guide on how to use Stable Diffusion πŸ€—. +## Quality -For more information on optimization or other guides, I recommend taking a look at the following: +In the last two sections, you learned how to optimize the speed of your pipeline by using `fp16`, reducing the number of inference steps by using a more performant scheduler, and enabling attention slicing to reduce memory consumption. Now you're going to focus on how to improve the quality of generated images. -- [Blog post about Stable Diffusion](https://huggingface.co/blog/stable_diffusion): In-detail blog post explaining Stable Diffusion. -- [FlashAttention](https://huggingface.co/docs/diffusers/optimization/xformers): XFormers flash attention can optimize your model even further with more speed and memory improvements. -- [Dreambooth](https://huggingface.co/docs/diffusers/training/dreambooth) - Quickly customize the model by fine-tuning it. -- [General info on Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/overview) - Info on other tasks that are powered by Stable Diffusion. +### Better checkpoints + +The most obvious step is to use better checkpoints. The Stable Diffusion model is a good starting point, and since its official launch, several improved versions have also been released. However, using a newer version doesn't automatically mean you'll get better results. You'll still have to experiment with different checkpoints yourself, and do a little research (such as using [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) to get the best results. + +As the field grows, there are more and more high-quality checkpoints finetuned to produce certain styles. Try exploring the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) and [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) to find one you're interested in! + +### Better pipeline components + +You can also try replacing the current pipeline components with a newer version. Let's try loading the latest [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) from Stability AI into the pipeline, and generate some images: + +```python +from diffusers import AutoencoderKL + +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda") +pipeline.vae = vae +images = pipeline(**get_inputs(batch_size=8)).images +image_grid(images, rows=2, cols=4) +``` + +
+ +
+ +### Better prompt engineering + +The text prompt you use to generate an image is super important, so much so that it is called *prompt engineering*. Some considerations to keep during prompt engineering are: + +- How is the image or similar images of the one I want to generate stored on the internet? +- What additional detail can I give that steers the model towards the style I want? + +With this in mind, let's improve the prompt to include color and higher quality details: + +```python +prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes" +prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta" +``` + +Generate a batch of images with the new prompt: + +```python +images = pipeline(**get_inputs(batch_size=8)).images +image_grid(images, rows=2, cols=4) +``` + +
+ +
+ +Pretty impressive! Let's tweak the second image - corresponding to the `Generator` with a seed of `1` - a bit more by adding some text about the age of the subject: + +```python +prommpts = [ + "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", + "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", + "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", + "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", +] + +generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))] +images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images +image_grid(images) +``` + +
+ +
+ +## Next steps + +In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources: + +- Enable [xFormers](./optimization/xformers) memory efficient attention mechanism for faster speed and reduced memory consumption. +- Learn how in [PyTorch 2.0](./optimization/torch2.0), [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 2-9% faster inference speed. +- Many optimization techniques for inference are also included in this memory and speed [guide](./optimization/fp16), such as memory offloading. \ No newline at end of file diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index 623b9124f303..908355e496dc 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -237,7 +237,7 @@ python train_dreambooth_flax.py \ ## Finetuning with LoRA -You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, on DreamBooth. For more details, take a look at the [LoRA training](training/lora#dreambooth) guide. +You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, on DreamBooth. For more details, take a look at the [LoRA training](./lora#dreambooth) guide. ## Saving checkpoints while training @@ -457,11 +457,11 @@ If you have **`"accelerate>=0.16.0"`** installed, you can use the following code inference from an intermediate checkpoint: ```python -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline import torch model_id = "path_to_saved_model" -pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") +pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") prompt = "A photo of sks dog in a bucket" image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] diff --git a/docs/source/en/training/text2image.mdx b/docs/source/en/training/text2image.mdx index 81dbfba92146..4f57ccf94de0 100644 --- a/docs/source/en/training/text2image.mdx +++ b/docs/source/en/training/text2image.mdx @@ -74,25 +74,13 @@ To load a checkpoint to resume training, pass the argument `--resume_from_checkp Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [PokΓ©mon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset like this: -```bash -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export dataset_name="lambdalabs/pokemon-blip-captions" - -accelerate launch train_text_to_image.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --dataset_name=$dataset_name \ - --use_ema \ - --resolution=512 --center_crop --random_flip \ - --train_batch_size=1 \ - --gradient_accumulation_steps=4 \ - --gradient_checkpointing \ - --mixed_precision="fp16" \ - --max_train_steps=15000 \ - --learning_rate=1e-05 \ - --max_grad_norm=1 \ - --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model" -``` + +{"path": "../../../../examples/text_to_image/README.md", +"language": "bash", +"start-after": "accelerate_snippet_start", +"end-before": "accelerate_snippet_end", +"dedent": 0} + To finetune on your own dataset, prepare the dataset according to the format required by πŸ€— [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub), or you can [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder). @@ -167,6 +155,28 @@ python train_text_to_image_flax.py \ +## Training with Min-SNR weighting + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence +by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended +value when using it is 5.0. + +You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: + +* Training without the Min-SNR weighting strategy +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) + +For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced. + +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + + + +Training with Min-SNR weighting strategy is only supported in PyTorch. + + + ## LoRA You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide. diff --git a/docs/source/en/training/text_inversion.mdx b/docs/source/en/training/text_inversion.mdx index 68c613849301..6e6971d7f119 100644 --- a/docs/source/en/training/text_inversion.mdx +++ b/docs/source/en/training/text_inversion.mdx @@ -157,24 +157,61 @@ If you're interested in following along with your model training progress, you c ## Inference -Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline`]. Make sure you include the `placeholder_token` in your prompt, in this case, it is ``. +Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline`]. + +The textual inversion script will by default only save the textual inversion embedding vector(s) that have +been added to the text encoder embedding matrix and consequently been trained. + + +πŸ’‘ The community has created a large library of different textual inversion embedding vectors, called [sd-concepts-library](https://huggingface.co/sd-concepts-library). +Instead of training textual inversion embeddings from scratch you can also see whether a fitting textual inversion embedding has already been added to the libary. + + + +To load the textual inversion embeddings you first need to load the base model that was used when training +your textual inversion embedding vectors. Here we assume that [`runwayml/stable-diffusion-v1-5`](runwayml/stable-diffusion-v1-5) +was used as a base model so we load it first: ```python from diffusers import StableDiffusionPipeline +import torch -model_id = "path-to-your-trained-model" +model_id = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") +``` -prompt = "A backpack" +Next, we need to load the textual inversion embedding vector which can be done via the [`TextualInversionLoaderMixin.load_textual_inversion`] +function. Here we'll load the embeddings of the "" example from before. +```python +pipe.load_textual_inversion("sd-concepts-library/cat-toy") +``` -image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] +Now we can run the pipeline making sure that the placeholder token `` is used in our prompt. +```python +prompt = "A backpack" + +image = pipe(prompt, num_inference_steps=50).images[0] image.save("cat-backpack.png") ``` + +The function [`TextualInversionLoaderMixin.load_textual_inversion`] can not only +load textual embedding vectors saved in Diffusers' format, but also embedding vectors +saved in [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) format. +To do so, you can first download an embedding vector from [civitAI](https://civitai.com/models/3036?modelVersionId=8387) +and then load it locally: +```python +pipe.load_textual_inversion("./charturnerv2.pt") +``` +Currently there is no `load_textual_inversion` function for Flax so one has to make sure the textual inversion +embedding vector is saved as part of the model after training. + +The model can then be run just like any other Flax model: + ```python import jax import numpy as np diff --git a/docs/source/en/tutorials/basic_training.mdx b/docs/source/en/tutorials/basic_training.mdx index 435de38d832f..52ce7c71fa68 100644 --- a/docs/source/en/tutorials/basic_training.mdx +++ b/docs/source/en/tutorials/basic_training.mdx @@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with πŸ€— Acce ... # Sample a random timestep for each image ... timesteps = torch.randint( -... 0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device +... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device ... ).long() ... # Add noise to the clean images according to the noise magnitude at each timestep diff --git a/docs/source/en/using-diffusers/conditional_image_generation.mdx b/docs/source/en/using-diffusers/conditional_image_generation.mdx index edd1cd926734..0b5c02415d87 100644 --- a/docs/source/en/using-diffusers/conditional_image_generation.mdx +++ b/docs/source/en/using-diffusers/conditional_image_generation.mdx @@ -10,22 +10,27 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Conditional Image Generation +# Conditional image generation + +[[open-in-colab]] + +Conditional image generation allows you to generate images from a text prompt. The text is converted into embeddings which are used to condition the model to generate an image from noise. The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference. -Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download. -You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads). -In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256): +Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) you would like to download. + +In this guide, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256): ```python >>> from diffusers import DiffusionPipeline >>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") ``` + The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. -Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU. -You can move the generator object to GPU, just like you would in PyTorch. +Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU. +You can move the generator object to a GPU, just like you would in PyTorch: ```python >>> generator.to("cuda") @@ -37,10 +42,19 @@ Now you can use the `generator` on your text prompt: >>> image = generator("An image of a squirrel in Picasso style").images[0] ``` -The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). +The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object. -You can save the image by simply calling: +You can save the image by calling: ```python >>> image.save("image_of_squirrel_painting.png") ``` + +Try out the Spaces below, and feel free to play around with the guidance scale parameter to see how it affects the image quality! + + \ No newline at end of file diff --git a/docs/source/en/using-diffusers/contribute_pipeline.mdx b/docs/source/en/using-diffusers/contribute_pipeline.mdx index ce3f3e823252..8ee6d6ae4fb1 100644 --- a/docs/source/en/using-diffusers/contribute_pipeline.mdx +++ b/docs/source/en/using-diffusers/contribute_pipeline.mdx @@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline): def __call__(self): image = torch.randn( - (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), ) timestep = 1 @@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline): def __call__(self): image = torch.randn( - (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), ) timestep = 1 diff --git a/docs/source/en/using-diffusers/custom_pipeline_examples.mdx b/docs/source/en/using-diffusers/custom_pipeline_examples.mdx index fd37c6dc1a60..2dfa71f0d33c 100644 --- a/docs/source/en/using-diffusers/custom_pipeline_examples.mdx +++ b/docs/source/en/using-diffusers/custom_pipeline_examples.mdx @@ -45,11 +45,11 @@ The following code requires roughly 12GB of GPU RAM. ```python from diffusers import DiffusionPipeline -from transformers import CLIPFeatureExtractor, CLIPModel +from transformers import CLIPImageProcessor, CLIPModel import torch -feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") +feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16) diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.mdx b/docs/source/en/using-diffusers/custom_pipeline_overview.mdx index 9b3f92e1c363..934e639983d2 100644 --- a/docs/source/en/using-diffusers/custom_pipeline_overview.mdx +++ b/docs/source/en/using-diffusers/custom_pipeline_overview.mdx @@ -50,11 +50,11 @@ and passing pipeline modules directly. ```python from diffusers import DiffusionPipeline -from transformers import CLIPFeatureExtractor, CLIPModel +from transformers import CLIPImageProcessor, CLIPModel clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" -feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) +feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id) clip_model = CLIPModel.from_pretrained(clip_model_id) pipeline = DiffusionPipeline.from_pretrained( @@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline): @torch.no_grad() def __call__(self, batch_size: int = 1, num_inference_steps: int = 50): # Sample gaussian noise to begin loop - image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)) + image = torch.randn( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size) + ) image = image.to(self.device) diff --git a/docs/source/en/using-diffusers/depth2img.mdx b/docs/source/en/using-diffusers/depth2img.mdx index eace64c3109a..a4141644b006 100644 --- a/docs/source/en/using-diffusers/depth2img.mdx +++ b/docs/source/en/using-diffusers/depth2img.mdx @@ -10,9 +10,13 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Text-Guided Image-to-Image Generation +# Text-guided depth-to-image generation -The [`StableDiffusionDepth2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images as well as a `depth_map` to preserve the images' structure. If no `depth_map` is provided, the pipeline will automatically predict the depth via an integrated depth-estimation model. +[[open-in-colab]] + +The [`StableDiffusionDepth2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. In addition, you can also pass a `depth_map` to preserve the image structure. If no `depth_map` is provided, the pipeline automatically predicts the depth via an integrated [depth-estimation model](https://github.com/isl-org/MiDaS). + +Start by creating an instance of the [`StableDiffusionDepth2ImgPipeline`]: ```python import torch @@ -25,11 +29,28 @@ pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-2-depth", torch_dtype=torch.float16, ).to("cuda") +``` +Now pass your prompt to the pipeline. You can also pass a `negative_prompt` to prevent certain words from guiding how an image is generated: +```python url = "http://images.cocodataset.org/val2017/000000039769.jpg" init_image = Image.open(requests.get(url, stream=True).raw) prompt = "two tigers" n_prompt = "bad, deformed, ugly, bad anatomy" image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0] +image ``` + +| Input | Output | +|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------| +| | | + +Play around with the Spaces below and see if you notice a difference between generated images with and without a depth map! + + diff --git a/docs/source/en/using-diffusers/img2img.mdx b/docs/source/en/using-diffusers/img2img.mdx index 6ebe1f0633f0..71540fbf5dd9 100644 --- a/docs/source/en/using-diffusers/img2img.mdx +++ b/docs/source/en/using-diffusers/img2img.mdx @@ -10,11 +10,11 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Text-Guided Image-to-Image Generation +# Text-guided image-to-image generation [[open-in-colab]] -The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. This tutorial shows how to use it for text-guided image-to-image generation with Stable Diffusion model. +The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. Before you begin, make sure you have all the necessary libraries installed: @@ -22,27 +22,22 @@ Before you begin, make sure you have all the necessary libraries installed: !pip install diffusers transformers ftfy accelerate ``` -Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model. +Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model like [`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion). ```python import torch import requests from PIL import Image from io import BytesIO - from diffusers import StableDiffusionImg2ImgPipeline -``` -Load the pipeline: - -```python device = "cuda" -pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to( +pipe = StableDiffusionImg2ImgPipeline.from_pretrained("nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16).to( device ) ``` -Download an initial image and preprocess it so we can pass it to the pipeline: +Download and preprocess an initial image so you can pass it to the pipeline: ```python url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" @@ -53,61 +48,52 @@ init_image.thumbnail((768, 768)) init_image ``` -![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_8_output_0.jpeg) - -Define the prompt and run the pipeline: - -```python -prompt = "A fantasy landscape, trending on artstation" -``` +
+ +
-`strength` is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. +πŸ’‘ `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. -Let's generate two images with same pipeline and seed, but with different values for `strength`: +Define the prompt (for this checkpoint finetuned on Ghibli-style art, you need to prefix the prompt with the `ghibli style` tokens) and run the pipeline: ```python +prompt = "ghibli style, a fantasy landscape with castles" generator = torch.Generator(device=device).manual_seed(1024) image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0] -``` - -```python image ``` -![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_13_output_0.jpeg) +
+ +
- -```python -image = pipe(prompt=prompt, image=init_image, strength=0.5, guidance_scale=7.5, generator=generator).images[0] -image -``` - -![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_14_output_1.jpeg) - - -As you can see, when using a lower value for `strength`, the generated image is more closer to the original `image`. - -Now let's use a different scheduler - [LMSDiscreteScheduler](https://huggingface.co/docs/diffusers/api/schedulers#diffusers.LMSDiscreteScheduler): +You can also try experimenting with a different scheduler to see how that affects the output: ```python from diffusers import LMSDiscreteScheduler lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config) pipe.scheduler = lms -``` - -```python generator = torch.Generator(device=device).manual_seed(1024) image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0] -``` - -```python image ``` -![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_19_output_0.jpeg) +
+ +
+ +Check out the Spaces below, and try generating images with different values for `strength`. You'll notice that using lower values for `strength` produces images that are more similar to the original image. + +Feel free to also switch the scheduler to the [`LMSDiscreteScheduler`] and see how that affects the output. + diff --git a/docs/source/en/using-diffusers/inpaint.mdx b/docs/source/en/using-diffusers/inpaint.mdx index 1fcd0e6a5142..41a6d4b7e1b2 100644 --- a/docs/source/en/using-diffusers/inpaint.mdx +++ b/docs/source/en/using-diffusers/inpaint.mdx @@ -10,9 +10,13 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Text-Guided Image-Inpainting +# Text-guided image-inpainting -The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image by providing a mask and a text prompt. It uses a version of Stable Diffusion specifically trained for in-painting tasks. +[[open-in-colab]] + +The [`StableDiffusionInpaintPipeline`] allows you to edit specific parts of an image by providing a mask and a text prompt. It uses a version of Stable Diffusion, like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) specifically trained for inpainting tasks. + +Get started by loading an instance of the [`StableDiffusionInpaintPipeline`]: ```python import PIL @@ -22,7 +26,16 @@ from io import BytesIO from diffusers import StableDiffusionInpaintPipeline +pipeline = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + torch_dtype=torch.float16, +) +pipeline = pipeline.to("cuda") +``` + +Download an image and a mask of a dog which you'll eventually replace: +```python def download_image(url): response = requests.get(url) return PIL.Image.open(BytesIO(response.content)).convert("RGB") @@ -33,24 +46,31 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data init_image = download_image(img_url).resize((512, 512)) mask_image = download_image(mask_url).resize((512, 512)) +``` -pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", - torch_dtype=torch.float16, -) -pipe = pipe.to("cuda") +Now you can create a prompt to replace the mask with something else: +```python prompt = "Face of a yellow cat, high resolution, sitting on a park bench" image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` -`image` | `mask_image` | `prompt` | **Output** | +`image` | `mask_image` | `prompt` | output | :-------------------------:|:-------------------------:|:-------------------------:|-------------------------:| drawing | drawing | ***Face of a yellow cat, high resolution, sitting on a park bench*** | drawing | -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) - -A previous experimental implementation of in-painting used a different, lower-quality process. To ensure backwards compatibility, loading a pretrained pipeline that doesn't contain the new model will still apply the old in-painting method. + +A previous experimental implementation of inpainting used a different, lower-quality process. To ensure backwards compatibility, loading a pretrained pipeline that doesn't contain the new model will still apply the old inpainting method. + + +Check out the Spaces below to try out image inpainting yourself! + + diff --git a/docs/source/en/using-diffusers/loading.mdx b/docs/source/en/using-diffusers/loading.mdx index c41315c995de..24dd1dd04cd1 100644 --- a/docs/source/en/using-diffusers/loading.mdx +++ b/docs/source/en/using-diffusers/loading.mdx @@ -10,20 +10,28 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Loading +# Load pipelines, models, and schedulers -A core premise of the diffusers library is to make diffusion models **as accessible as possible**. -Accessibility is therefore achieved by providing an API to load complete diffusion pipelines as well as individual components with a single line of code. +Having an easy way to use a diffusion system for inference is essential to 🧨 Diffusers. Diffusion systems often consist of multiple components like parameterized models, tokenizers, and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API, while remaining flexible enough to be adapted for other use cases, such as loading each component individually as building blocks to assemble your own diffusion system. -In the following we explain in-detail how to easily load: +Everything you need for inference or training is accessible with the `from_pretrained()` method. -- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`] -- *Diffusion Models* via [`ModelMixin.from_pretrained`] -- *Schedulers* via [`SchedulerMixin.from_pretrained`] +This guide will show you how to load: -## Loading pipelines +- pipelines from the Hub and locally +- different components into a pipeline +- checkpoint variants such as different floating point types or non-exponential mean averaged (EMA) weights +- models and schedulers -The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [Runway's Stable Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5). +## Diffusion Pipeline + + + +πŸ’‘ Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you interested in learning in more detail about how the [`DiffusionPipeline`] class works. + + + +The [`DiffusionPipeline`] class is the simplest and most generic way to load any diffusion model from the [Hub](https://huggingface.co/models?library=diffusers). The [`DiffusionPipeline.from_pretrained`] method automatically detects the correct pipeline class from the checkpoint, downloads and caches all the required configuration and weight files, and returns a pipeline instance ready for inference. ```python from diffusers import DiffusionPipeline @@ -32,10 +40,7 @@ repo_id = "runwayml/stable-diffusion-v1-5" pipe = DiffusionPipeline.from_pretrained(repo_id) ``` -Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`StableDiffusionPipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `pipe`. -The pipeline instance can then be called using [`StableDiffusionPipeline.__call__`] (i.e., `pipe("image of a astronaut riding a horse")`) for text-to-image generation. - -Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing: +You can also load a checkpoint with it's specific pipeline class. The example above loaded a Stable Diffusion model; to get the same result, use the [`StableDiffusionPipeline`] class: ```python from diffusers import StableDiffusionPipeline @@ -44,10 +49,7 @@ repo_id = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(repo_id) ``` - - -Many checkpoints, such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for multiple tasks, *e.g.* *text-to-image* or *image-to-image*. -If you want to use those checkpoints for a task that is different from the default one, you have to load it directly from the corresponding task-specific pipeline class: +A checkpoint (such as [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)) may also be used for more than one task, like text-to-image or image-to-image. To differentiate what task you want to use the checkpoint for, you have to load it directly with it's corresponding task-specific pipeline class: ```python from diffusers import StableDiffusionImg2ImgPipeline @@ -56,101 +58,47 @@ repo_id = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id) ``` - - - -Diffusion pipelines like `StableDiffusionPipeline` or `StableDiffusionImg2ImgPipeline` consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vae"` and `"text_encoder"`, tokenizers or schedulers. -These components often interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work). -The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later. +### Local pipeline - - -### Loading pipelines locally - -If you prefer to have complete control over the pipeline and its corresponding files or, as said before, if you want to use pipelines that require an access request without having to be connected to the Hugging Face Hub, -we recommend loading pipelines locally. - -To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for -[Runway's Stable Diffusion Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5). - -First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main): - -``` -git lfs install -git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 -``` - -The command above will create a local folder called `./stable-diffusion-v1-5` on your disk. -Now, all you have to do is to simply pass the local folder path to `from_pretrained`: - -```python -from diffusers import DiffusionPipeline - -repo_id = "./stable-diffusion-v1-5" stable_diffusion = DiffusionPipeline.from_pretrained(repo_id) +stable_diffusion.scheduler.compatibles ``` -If `repo_id` is a local path, as it is the case here, [`DiffusionPipeline.from_pretrained`] will automatically detect it and therefore not try to download any files from the Hub. -While we usually recommend to load weights directly from the Hub to be certain to stay up to date with the newest changes, loading pipelines locally should be preferred if one -wants to stay anonymous, self-contained applications, etc... - -### Loading customized pipelines +Let's use the [`SchedulerMixin.from_pretrained`] method to replace the default [`PNDMScheduler`] with a more performant scheduler, [`EulerDiscreteScheduler`]. The `subfolder="scheduler"` argument is required to load the scheduler configuration from the correct [subfolder](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) of the pipeline repository. -Advanced users that want to load customized versions of diffusion pipelines can do so by swapping any of the default components, *e.g.* the scheduler, with other scheduler classes. -A classical use case of this functionality is to swap the scheduler. [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) uses the [`PNDMScheduler`] by default which is generally not the most performant scheduler. Since the release -of stable diffusion, multiple improved schedulers have been published. To use those, the user has to manually load their preferred scheduler and pass it into [`DiffusionPipeline.from_pretrained`]. - -*E.g.* to use [`EulerDiscreteScheduler`] or [`DPMSolverMultistepScheduler`] to have a better quality vs. generation speed trade-off for inference, one could load them as follows: +Then you can pass the new [`EulerDiscreteScheduler`] instance to the `scheduler` argument in [`DiffusionPipeline`]: ```python from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler @@ -158,31 +106,24 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis repo_id = "runwayml/stable-diffusion-v1-5" scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") -# or -# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler) ``` -Three things are worth paying attention to here. -- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`] -- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) -- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`] - -Not only the scheduler components can be customized for diffusion pipelines; in theory, all components of a pipeline can be customized. In practice, however, it often only makes sense to switch out a component that has **compatible** alternatives to what the pipeline expects. -Many scheduler classes are compatible with each other as can be seen [here](https://github.com/huggingface/diffusers/blob/0dd8c6b4dbab4069de9ed1cafb53cbd495873879/src/diffusers/schedulers/scheduling_ddim.py#L112). This is not always the case for other components, such as the `"unet"`. +### Safety checker -One special case that can also be customized is the `"safety_checker"` of stable diffusion. If you believe the safety checker doesn't serve you any good, you can simply disable it by passing `None`: +Diffusion models like Stable Diffusion can generate harmful content, which is why 🧨 Diffusers has a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to check generated outputs against known hardcoded NSFW content. If you'd like to disable the safety checker for whatever reason, pass `None` to the `safety_checker` argument: ```python -from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler +from diffusers import DiffusionPipeline +repo_id = "runwayml/stable-diffusion-v1-5" stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None) ``` -Another common use case is to reuse the same components in multiple pipelines, *e.g.* the weights and configurations of [`"runwayml/stable-diffusion-v1-5"`](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for both [`StableDiffusionPipeline`] and [`StableDiffusionImg2ImgPipeline`] and we might not want to -use the exact same weights into RAM twice. In this case, customizing all the input instances would help us -to only load the weights into RAM once: +### Reuse components across pipelines + +You can also reuse the same components in multiple pipelines to avoid loading the weights into RAM twice. Use the [`~DiffusionPipeline.components`] method to save the components: ```python from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline @@ -191,231 +132,216 @@ model_id = "runwayml/stable-diffusion-v1-5" stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id) components = stable_diffusion_txt2img.components +``` + +Then you can pass the `components` to another pipeline without reloading the weights into RAM: -# weights are not reloaded into RAM +```py stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components) ``` -Note how the above code snippet makes use of [`DiffusionPipeline.components`]. +You can also pass the components individually to the pipeline if you want more flexibility over which components to reuse or disable. For example, to reuse the same components in the text-to-image pipeline, except for the safety checker and feature extractor, in the image-to-image pipeline: -### Loading variants +```py +from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline + +model_id = "runwayml/stable-diffusion-v1-5" +stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id) +stable_diffusion_img2img = StableDiffusionImg2ImgPipeline( + vae=stable_diffusion_txt2img.vae, + text_encoder=stable_diffusion_txt2img.text_encoder, + tokenizer=stable_diffusion_txt2img.tokenizer, + unet=stable_diffusion_txt2img.unet, + scheduler=stable_diffusion_txt2img.scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, +) +``` -Diffusion Pipeline checkpoints can offer variants of the "main" diffusion pipeline checkpoint. -Such checkpoint variants are usually variations of the checkpoint that have advantages for specific use-cases and that are so similar to the "main" checkpoint that they **should not** be put in a new checkpoint. -A variation of a checkpoint has to have **exactly** the same serialization format and **exactly** the same model structure, including all weights having the same tensor shapes. +## Checkpoint variants -Examples of variations are different floating point types and non-ema weights. I.e. "fp16", "bf16", and "no_ema" are common variations. +A checkpoint variant is usually a checkpoint where it's weights are: -#### Let's first talk about whats **not** checkpoint variant, +- Stored in a different floating point type for lower precision and lower storage, such as [`torch.float16`](https://pytorch.org/docs/stable/tensors.html#data-types), because it only requires half the bandwidth and storage to download. You can't use this variant if you're continuing training or using a CPU. +- Non-exponential mean averaged (EMA) weights which shouldn't be used for inference. You should use these to continue finetuning a model. -Checkpoint variants do **not** include different serialization formats (such as [safetensors](https://huggingface.co/docs/diffusers/main/en/using-diffusers/using_safetensors)) as weights in different serialization formats are -identical to the weights of the "main" checkpoint, just loaded in a different framework. + -Also variants do not correspond to different model structures, *e.g.* [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is not a variant of [stable-diffusion-2-0](https://huggingface.co/stabilityai/stable-diffusion-2) since the model structure is different (Stable Diffusion 1-5 uses a different `CLIPTextModel` compared to Stable Diffusion 2.0). +πŸ’‘ When the checkpoints have identical model structures, but they were trained on different datasets and with a different training setup, they should be stored in separate repositories instead of variations (for example, [`stable-diffusion-v1-4`] and [`stable-diffusion-v1-5`]). -Pipeline checkpoints that are identical in model structure, but have been trained on different datasets, trained with vastly different training setups and thus correspond to different official releases (such as [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)) should probably be stored in individual repositories instead of as variations of each other. + -#### So what are checkpoint variants then? +Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [Safetensors](./using-diffusers/using_safetensors)), model structure, and weights have identical tensor shapes. -Checkpoint variants usually consist of the checkpoint stored in "*low-precision, low-storage*" dtype so that less bandwith is required to download them, or of *non-exponential-averaged* weights that shall be used when continuing fine-tuning from the checkpoint. -Both use cases have clear advantages when their weights are considered variants: they share the same serialization format as the reference weights, and they correspond to a specialization of the "main" checkpoint which does not warrant a new model repository. -A checkpoint stored in [torch's half-precision / float16 format](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) requires only half the bandwith and storage when downloading the checkpoint, -**but** cannot be used when continuing training or when running the checkpoint on CPU. -Similarly the *non-exponential-averaged* (or non-EMA) version of the checkpoint should be used when continuing fine-tuning of the model checkpoint, **but** should not be used when using the checkpoint for inference. +| **checkpoint type** | **weight name** | **argument for loading weights** | +|---------------------|-------------------------------------|----------------------------------| +| original | diffusion_pytorch_model.bin | | +| floating point | diffusion_pytorch_model.fp16.bin | `variant`, `torch_dtype` | +| non-EMA | diffusion_pytorch_model.non_ema.bin | `variant` | -#### How to save and load variants +There are two important arguments to know for loading variants: -Saving a diffusion pipeline as a variant can be done by providing [`DiffusionPipeline.save_pretrained`] with the `variant` argument. -The `variant` extends the weight name by the provided variation, by changing the default weight name from `diffusion_pytorch_model.bin` to `diffusion_pytorch_model.{variant}.bin` or from `diffusion_pytorch_model.safetensors` to `diffusion_pytorch_model.{variant}.safetensors`. By doing so, one creates a variant of the pipeline checkpoint that can be loaded **instead** of the "main" pipeline checkpoint. +- `torch_dtype` defines the floating point precision of the loaded checkpoints. For example, if you want to save bandwidth by loading a `fp16` variant, you should specify `torch_dtype=torch.float16` to *convert the weights* to `fp16`. Otherwise, the `fp16` weights are converted to the default `fp32` precision. You can also load the original checkpoint without defining the `variant` argument, and convert it to `fp16` with `torch_dtype=torch.float16`. In this case, the default `fp32` weights are downloaded first, and then they're converted to `fp16` after loading. -Let's have a look at how we could create a float16 variant of a pipeline. First, we load -the "main" variant of a checkpoint (stored in `float32` precision) into mixed precision format, using `torch_dtype=torch.float16`. +- `variant` defines which files should be loaded from the repository. For example, if you want to load a `non_ema` variant from the [`diffusers/stable-diffusion-variants`](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main/unet) repository, you should specify `variant="non_ema"` to download the `non_ema` files. -```py +```python from diffusers import DiffusionPipeline -import torch -pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +# load fp16 variant +stable_diffusion = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16 +) +# load non_ema variant +stable_diffusion = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema") ``` -Now all model components of the pipeline are stored in half-precision dtype. We can now save the -pipeline under a `"fp16"` variant as follows: +To save a checkpoint stored in a different floating point type or as a non-EMA variant, use the [`DiffusionPipeline.save_pretrained`] method and specify the `variant` argument. You should try and save a variant to the same folder as the original checkpoint, so you can load both from the same folder: -```py -pipe.save_pretrained("./stable-diffusion-v1-5", variant="fp16") +```python +from diffusers import DiffusionPipeline + +# save as fp16 variant +stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16") +# save as non-ema variant +stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema") ``` -If we don't save into an existing `stable-diffusion-v1-5` folder the new folder would look as follows: +If you don't save the variant to an existing folder, you must specify the `variant` argument otherwise it'll throw an `Exception` because it can't find the original checkpoint: -``` -stable-diffusion-v1-5 -β”œβ”€β”€ feature_extractor -β”‚Β Β  └── preprocessor_config.json -β”œβ”€β”€ model_index.json -β”œβ”€β”€ safety_checker -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  └── pytorch_model.fp16.bin -β”œβ”€β”€ scheduler -β”‚Β Β  └── scheduler_config.json -β”œβ”€β”€ text_encoder -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  └── pytorch_model.fp16.bin -β”œβ”€β”€ tokenizer -β”‚Β Β  β”œβ”€β”€ merges.txt -β”‚Β Β  β”œβ”€β”€ special_tokens_map.json -β”‚Β Β  β”œβ”€β”€ tokenizer_config.json -β”‚Β Β  └── vocab.json -β”œβ”€β”€ unet -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  └── diffusion_pytorch_model.fp16.bin -└── vae - β”œβ”€β”€ config.json - └── diffusion_pytorch_model.fp16.bin +```python +# πŸ‘Ž this won't work +stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16) +# πŸ‘ this works +stable_diffusion = DiffusionPipeline.from_pretrained( + "./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16 +) ``` -As one can see, all model files now have a `.fp16.bin` extension instead of just `.bin`. -The variant now has to be loaded by also passing a `variant="fp16"` to [`DiffusionPipeline.from_pretrained`], e.g.: + -We can now both download the "main" and the "fp16" variant from the Hub. Both: +## Models -```py -pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants") +Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of redownloading them. + +Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for `runwayml/stable-diffusion-v1-5` are stored in the [`unet`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet) subfolder: + +```python +from diffusers import UNet2DConditionModel + +repo_id = "runwayml/stable-diffusion-v1-5" +model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet") ``` -and +Or directly from a repository's [directory](https://huggingface.co/google/ddpm-cifar10-32/tree/main): -```py -pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="fp16") +```python +from diffusers import UNet2DModel + +repo_id = "google/ddpm-cifar10-32" +model = UNet2DModel.from_pretrained(repo_id) ``` -work. +You can also load and save model variants by specifying the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`]: - +```python +from diffusers import UNet2DConditionModel + +model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non-ema") +model.save_pretrained("./local-unet", variant="non-ema") +``` -Note that Diffusers never downloads more checkpoints than needed. E.g. when downloading -the "main" variant, none of the "fp16.bin" files are downloaded and cached. -Only when the user specifies `variant="fp16"` are those files downloaded and cached. +## Schedulers - +Schedulers are loaded from the [`SchedulerMixin.from_pretrained`] method, and unlike models, schedulers are **not parameterized** or **trained**; they are defined by a configuration file. -Finally, there are cases where only some of the checkpoint files of the pipeline are of a certain -variation. E.g. it's usually only the UNet checkpoint that has both a *exponential-mean-averaged* (EMA) and a *non-exponential-mean-averaged* (non-EMA) version. All other model components, e.g. the text encoder, safety checker or variational auto-encoder usually don't have such a variation. -In such a case, one would upload just the UNet's checkpoint file with a `non_ema` version format (as done [here](https://huggingface.co/diffusers/stable-diffusion-variants/blob/main/unet/diffusion_pytorch_model.non_ema.bin)) and upon calling: +Loading schedulers does not consume any significant amount of memory and the same configuration file can be used for a variety of different schedulers. +For example, the following schedulers are compatible with [`StableDiffusionPipeline`] which means you can load the same scheduler configuration file in any of these classes: ```python -pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="non_ema") -``` +from diffusers import StableDiffusionPipeline +from diffusers import ( + DDPMScheduler, + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) -the model will use only the "non_ema" checkpoint variant if it is available - otherwise it'll load the -"main" variation. In the above example, `variant="non_ema"` would therefore download the following file structure: +repo_id = "runwayml/stable-diffusion-v1-5" -``` -β”œβ”€β”€ feature_extractor -β”‚Β Β  └── preprocessor_config.json -β”œβ”€β”€ model_index.json -β”œβ”€β”€ safety_checker -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  β”œβ”€β”€ pytorch_model.bin -β”œβ”€β”€ scheduler -β”‚Β Β  └── scheduler_config.json -β”œβ”€β”€ text_encoder -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  β”œβ”€β”€ pytorch_model.bin -β”œβ”€β”€ tokenizer -β”‚Β Β  β”œβ”€β”€ merges.txt -β”‚Β Β  β”œβ”€β”€ special_tokens_map.json -β”‚Β Β  β”œβ”€β”€ tokenizer_config.json -β”‚Β Β  └── vocab.json -β”œβ”€β”€ unet -β”‚Β Β  β”œβ”€β”€ config.json -β”‚Β Β  └── diffusion_pytorch_model.non_ema.bin -└── vae - β”œβ”€β”€ config.json - β”œβ”€β”€ diffusion_pytorch_model.bin -``` +ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler") +ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler") +pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") -In a nutshell, using `variant="{variant}"` will download all files that match the `{variant}` and if for a model component such a file variant is not present it will download the "main" variant. If neither a "main" or `{variant}` variant is available, an error will the thrown. +# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler` +pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm) +``` -### How does loading work? +## DiffusionPipeline explained As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things: -- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files. -- Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file. -The underlying folder structure of diffusion pipelines corresponds 1-to-1 to their corresponding class instances, *e.g.* [`StableDiffusionPipeline`] for [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5). -This can be better understood by looking at an example. Let's load a pipeline class instance `pipe` and print it: +- Download the latest version of the folder structure required for inference and cache it. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] reuses the cache and won't redownload the files. +- Load the cached weights into the correct pipeline [class](./api/pipelines/overview#diffusers-summary) - retrieved from the `model_index.json` file - and return an instance of it. + +The pipelines underlying folder structure corresponds directly with their class instances. For example, the [`StableDiffusionPipeline`] corresponds to the folder structure in [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5). ```python from diffusers import DiffusionPipeline repo_id = "runwayml/stable-diffusion-v1-5" -pipe = DiffusionPipeline.from_pretrained(repo_id) -print(pipe) +pipeline = DiffusionPipeline.from_pretrained(repo_id) +print(pipeline) ``` -*Output*: -``` +You'll see pipeline is an instance of [`StableDiffusionPipeline`], which consists of seven components: + +- `"feature_extractor"`: a [`~transformers.CLIPFeatureExtractor`] from πŸ€— Transformers. +- `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening against harmful content. +- `"scheduler"`: an instance of [`PNDMScheduler`]. +- `"text_encoder"`: a [`~transformers.CLIPTextModel`] from πŸ€— Transformers. +- `"tokenizer"`: a [`~transformers.CLIPTokenizer`] from πŸ€— Transformers. +- `"unet"`: an instance of [`UNet2DConditionModel`]. +- `"vae"` an instance of [`AutoencoderKL`]. + +```json StableDiffusionPipeline { "feature_extractor": [ "transformers", - "CLIPFeatureExtractor" + "CLIPImageProcessor" ], "safety_checker": [ "stable_diffusion", @@ -444,16 +370,7 @@ StableDiffusionPipeline { } ``` -First, we see that the official pipeline is the [`StableDiffusionPipeline`], and second we see that the `StableDiffusionPipeline` consists of 7 components: -- `"feature_extractor"` of class `CLIPFeatureExtractor` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPFeatureExtractor). -- `"safety_checker"` as defined [here](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32). -- `"scheduler"` of class [`PNDMScheduler`]. -- `"text_encoder"` of class `CLIPTextModel` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel). -- `"tokenizer"` of class `CLIPTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer). -- `"unet"` of class [`UNet2DConditionModel`]. -- `"vae"` of class [`AutoencoderKL`]. - -Let's now compare the pipeline instance to the folder structure of the model repository `runwayml/stable-diffusion-v1-5`. Looking at the folder structure of [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) on the Hub and excluding model and saving format variants, we can see it matches 1-to-1 the printed out instance of `StableDiffusionPipeline` above: +Compare the components of the pipeline instance to the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) folder structure, and you'll see there is a separate folder for each of the components in the repository: ``` . @@ -481,19 +398,39 @@ Let's now compare the pipeline instance to the folder structure of the model rep β”œβ”€β”€ diffusion_pytorch_model.bin ``` -Each attribute of the instance of `StableDiffusionPipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"feature_extractor"`, `"safety_checker"`, `"scheduler"`, `"text_encoder"`, `"tokenizer"`, `"unet"`, `"vae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both: -- which pipeline class should be loaded, and -- what sub-classes from which library are stored in which subfolders - -In the case of `runwayml/stable-diffusion-v1-5` the `model_index.json` is therefore defined as follows: +You can access each of the components of the pipeline as an attribute to view its configuration: +```py +pipeline.tokenizer +CLIPTokenizer( + name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer", + vocab_size=49408, + model_max_length=77, + is_fast=False, + padding_side="right", + truncation_side="right", + special_tokens={ + "bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), + "eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), + "unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), + "pad_token": "<|endoftext|>", + }, +) ``` + +Every pipeline expects a `model_index.json` file that tells the [`DiffusionPipeline`]: + +- which pipeline class to load from `_class_name` +- which version of 🧨 Diffusers was used to create the model in `_diffusers_version` +- what components from which library are stored in the subfolders (`name` corresponds to the component and subfolder name, `library` corresponds to the name of the library to load the class from, and `class` corresponds to the class name) + +```json { "_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.6.0", "feature_extractor": [ "transformers", - "CLIPFeatureExtractor" + "CLIPImageProcessor" ], "safety_checker": [ "stable_diffusion", @@ -520,138 +457,4 @@ In the case of `runwayml/stable-diffusion-v1-5` the `model_index.json` is theref "AutoencoderKL" ] } -``` - -- `_class_name` tells `DiffusionPipeline` which pipeline class should be loaded. -- `_diffusers_version` can be useful to know under which `diffusers` version this model was created. -- Every component of the pipeline is then defined under the form: -``` -"name" : [ - "library", - "class" -] -``` - - The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)) - - The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded - - The `"class"` field corresponds to the name of the class, *e.g.* [`CLIPTokenizer`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer) or [`UNet2DConditionModel`] - - - -## Loading models - -Models as defined under [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) can be loaded via the [`ModelMixin.from_pretrained`] function. The API is very similar the [`DiffusionPipeline.from_pretrained`] and works in the same way: -- Download the latest version of the model weights and configuration with `diffusers` and cache them. If the latest files are available in the local cache, [`ModelMixin.from_pretrained`] will simply reuse the cache and **not** re-download the files. -- Load the cached weights into the _defined_ model class - one of [the existing model classes](./api/models) - and return an instance of the class. - -In constrast to [`DiffusionPipeline.from_pretrained`], models rely on fewer files that usually don't require a folder structure, but just a `diffusion_pytorch_model.bin` and `config.json` file. - -Let's look at an example: - -```python -from diffusers import UNet2DConditionModel - -repo_id = "runwayml/stable-diffusion-v1-5" -model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet") -``` - -Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet). - -As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]: - -```python -from diffusers import DiffusionPipeline - -repo_id = "runwayml/stable-diffusion-v1-5" -pipe = DiffusionPipeline.from_pretrained(repo_id, unet=model) -``` - -If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't -need to pass a `subfolder` argument: - -```python -from diffusers import UNet2DModel - -repo_id = "google/ddpm-cifar10-32" -model = UNet2DModel.from_pretrained(repo_id) -``` - -As motivated in [How to save and load variants?](#how-to-save-and-load-variants), models can load and -save variants. To load a model variant, one should pass the `variant` function argument to [`ModelMixin.from_pretrained`]. Analogous, to save a model variant, one should pass the `variant` function argument to [`ModelMixin.save_pretrained`]: - -```python -from diffusers import UNet2DConditionModel - -model = UNet2DConditionModel.from_pretrained( - "diffusers/stable-diffusion-variants", subfolder="unet", variant="non_ema" -) -model.save_pretrained("./local-unet", variant="non_ema") -``` - -## Loading schedulers - -Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. -For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case. - -In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers. -For example, all of: - -- [`DDPMScheduler`] -- [`DDIMScheduler`] -- [`PNDMScheduler`] -- [`LMSDiscreteScheduler`] -- [`EulerDiscreteScheduler`] -- [`EulerAncestralDiscreteScheduler`] -- [`DPMSolverMultistepScheduler`] - -are compatible with [`StableDiffusionPipeline`] and therefore the same scheduler configuration file can be loaded in any of those classes: - -```python -from diffusers import StableDiffusionPipeline -from diffusers import ( - DDPMScheduler, - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, -) - -repo_id = "runwayml/stable-diffusion-v1-5" - -ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler") -ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler") -pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler") -lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") -euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") -euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") -dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") - -# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler` -pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm) -``` +``` \ No newline at end of file diff --git a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.mdx b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.mdx new file mode 100644 index 000000000000..e0332fdc6496 --- /dev/null +++ b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.mdx @@ -0,0 +1,250 @@ +# 🧨 Stable Diffusion in JAX / Flax ! + +[[open-in-colab]] + +πŸ€— Hugging Face [Diffusers](https://github.com/huggingface/diffusers) supports Flax since version `0.5.1`! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform. + +This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works or want to run it in GPU, please refer to [this notebook](https://huggingface.co/docs/diffusers/stable_diffusion). + +First, make sure you are using a TPU backend. If you are running this notebook in Colab, select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting. + +Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel. + +## Setup + +First make sure diffusers is installed. + +```bash +!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy +!pip install diffusers +``` + +```python +import jax.tools.colab_tpu + +jax.tools.colab_tpu.setup_tpu() +import jax +``` + +```python +num_devices = jax.device_count() +device_type = jax.devices()[0].device_kind + +print(f"Found {num_devices} JAX devices of type {device_type}.") +assert ( + "TPU" in device_type +), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator" +``` + +```python out +Found 8 JAX devices of type Cloud TPU. +``` + +Then we import all the dependencies. + +```python +import numpy as np +import jax +import jax.numpy as jnp + +from pathlib import Path +from jax import pmap +from flax.jax_utils import replicate +from flax.training.common_utils import shard +from PIL import Image + +from huggingface_hub import notebook_login +from diffusers import FlaxStableDiffusionPipeline +``` + +## Model Loading + +TPU devices support `bfloat16`, an efficient half-float type. We'll use it for our tests, but you can also use `float32` to use full precision instead. + +```python +dtype = jnp.bfloat16 +``` + +Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a `bf16` version of the weights, which leads to type warnings that you can safely ignore. + +```python +pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="bf16", + dtype=dtype, +) +``` + +## Inference + +Since TPUs usually have 8 devices working in parallel, we'll replicate our prompt as many times as devices we have. Then we'll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we'll get 8 images in the same amount of time it takes for one chip to generate a single one. + +After replicating the prompt, we obtain the tokenized text ids by invoking the `prepare_inputs` function of the pipeline. The length of the tokenized text is set to 77 tokens, as required by the configuration of the underlying CLIP Text model. + +```python +prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic" +prompt = [prompt] * jax.device_count() +prompt_ids = pipeline.prepare_inputs(prompt) +prompt_ids.shape +``` + +```python out +(8, 77) +``` + +### Replication and parallelization + +Model parameters and inputs have to be replicated across the 8 parallel devices we have. The parameters dictionary is replicated using `flax.jax_utils.replicate`, which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`. + +```python +p_params = replicate(params) +``` + +```python +prompt_ids = shard(prompt_ids) +prompt_ids.shape +``` + +```python out +(8, 1, 77) +``` + +That shape means that each one of the `8` devices will receive as an input a `jnp` array with shape `(1, 77)`. `1` is therefore the batch size per device. In TPUs with sufficient memory, it could be larger than `1` if we wanted to generate multiple images (per chip) at once. + +We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers – all functions that deal with random numbers are expected to receive a generator. This ensures reproducibility, even when we are training across multiple distributed devices. + +The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we'll get the exact same results. Feel free to use different seeds when exploring results later in the notebook. + +```python +def create_key(seed=0): + return jax.random.PRNGKey(seed) +``` + +We obtain a rng and then "split" it 8 times so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible. + +```python +rng = create_key(0) +rng = jax.random.split(rng, jax.device_count()) +``` + +JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn't be able to take advantage of the optimized speed. + +The Flax pipeline can compile the code for us if we pass `jit = True` as an argument. It will also ensure that the model runs in parallel in the 8 available devices. + +The first time we run the following cell it will take a long time to compile, but subequent calls (even with different inputs) will be much faster. For example, it took more than a minute to compile in a TPU v2-8 when I tested, but then it takes about **`7s`** for future inference runs. + +``` +%%time +images = pipeline(prompt_ids, p_params, rng, jit=True)[0] +``` + +```python out +CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s +Wall time: 1min 29s +``` + +The returned array has shape `(8, 1, 512, 512, 3)`. We reshape it to get rid of the second dimension and obtain 8 images of `512 Γ— 512 Γ— 3` and then convert them to PIL. + +```python +images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) +images = pipeline.numpy_to_pil(images) +``` + +### Visualization + +Let's create a helper function to display images in a grid. + +```python +def image_grid(imgs, rows, cols): + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid +``` + +```python +image_grid(images, 2, 4) +``` + +![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg) + + +## Using different prompts + +We don't have to replicate the _same_ prompt in all the devices. We can do whatever we want: generate 2 prompts 4 times each, or even generate 8 different prompts at once. Let's do that! + +First, we'll refactor the input preparation code into a handy function: + +```python +prompts = [ + "Labrador in the style of Hokusai", + "Painting of a squirrel skating in New York", + "HAL-9000 in the style of Van Gogh", + "Times Square under water, with fish and a dolphin swimming around", + "Ancient Roman fresco showing a man working on his laptop", + "Close-up photograph of young black woman against urban background, high quality, bokeh", + "Armchair in the shape of an avocado", + "Clown astronaut in space, with Earth in the background", +] +``` + +```python +prompt_ids = pipeline.prepare_inputs(prompts) +prompt_ids = shard(prompt_ids) + +images = pipeline(prompt_ids, p_params, rng, jit=True).images +images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) +images = pipeline.numpy_to_pil(images) + +image_grid(images, 2, 4) +``` + +![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg) + + +## How does parallelization work? + +We said before that the `diffusers` Flax pipeline automatically compiles the model and runs it in parallel on all available devices. We'll now briefly look inside that process to show how it works. + +JAX parallelization can be done in multiple ways. The easiest one revolves around using the `jax.pmap` function to achieve single-program, multiple-data (SPMD) parallelization. It means we'll run several copies of the same code, each on different data inputs. More sophisticated approaches are possible, we invite you to go over the [JAX documentation](https://jax.readthedocs.io/en/latest/index.html) and the [`pjit` pages](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit) to explore this topic if you are interested! + +`jax.pmap` does two things for us: +- Compiles (or `jit`s) the code, as if we had invoked `jax.jit()`. This does not happen when we call `pmap`, but the first time the pmapped function is invoked. +- Ensures the compiled code runs in parallel in all the available devices. + +To show how it works we `pmap` the `_generate` method of the pipeline, which is the private method that runs generates images. Please, note that this method may be renamed or removed in future releases of `diffusers`. + +```python +p_generate = pmap(pipeline._generate) +``` + +After we use `pmap`, the prepared function `p_generate` will conceptually do the following: +* Invoke a copy of the underlying function `pipeline._generate` in each device. +* Send each device a different portion of the input arguments. That's what sharding is used for. In our case, `prompt_ids` has shape `(8, 1, 77, 768)`. This array will be split in `8` and each copy of `_generate` will receive an input with shape `(1, 77, 768)`. + +We can code `_generate` completely ignoring the fact that it will be invoked in parallel. We just care about our batch size (`1` in this example) and the dimensions that make sense for our code, and don't have to change anything to make it work in parallel. + +The same way as when we used the pipeline call, the first time we run the following cell it will take a while, but then it will be much faster. + +``` +%%time +images = p_generate(prompt_ids, p_params, rng) +images = images.block_until_ready() +images.shape +``` + +```python out +CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s +Wall time: 1min 15s +``` + +```python +images.shape +``` + +```python out +(8, 1, 512, 512, 3) +``` + +We use `block_until_ready()` to correctly measure inference time, because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking will occur automatically when you want to use the result of a computation that has not yet been materialized. \ No newline at end of file diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.mdx b/docs/source/en/using-diffusers/unconditional_image_generation.mdx index b1722517cc26..c0888f94c6c1 100644 --- a/docs/source/en/using-diffusers/unconditional_image_generation.mdx +++ b/docs/source/en/using-diffusers/unconditional_image_generation.mdx @@ -10,43 +10,60 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Unconditional image generation +[[open-in-colab]] -# Unconditional Image Generation +Unconditional image generation is a relatively straightforward task. The model only generates images - without any additional context like text or an image - resembling the training data it was trained on. The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference. Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download. -You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads). -In this guide though, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239): +You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies). + + + +πŸ’‘ Want to train your own unconditional image generation model? Take a look at the training [guide](training/unconditional_training) to learn how to generate your own images. + + + +In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239): ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256") +>>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128") ``` + The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. -Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU. -You can move the generator object to GPU, just like you would in PyTorch. +Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU. +You can move the generator object to a GPU, just like you would in PyTorch: ```python >>> generator.to("cuda") ``` -Now you can use the `generator` on your text prompt: +Now you can use the `generator` to generate an image: ```python >>> image = generator().images[0] ``` -The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). +The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object. -You can save the image by simply calling: +You can save the image by calling: ```python >>> image.save("generated_image.png") ``` +Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality! + diff --git a/docs/source/en/using-diffusers/using_safetensors.mdx b/docs/source/en/using-diffusers/using_safetensors.mdx index 50bcb6b9933b..b522f3236fbb 100644 --- a/docs/source/en/using-diffusers/using_safetensors.mdx +++ b/docs/source/en/using-diffusers/using_safetensors.mdx @@ -75,9 +75,9 @@ And we're equipped with dealing with it. Then in order to use the model, even before the branch gets accepted by the original author you can do: ```python -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", revision="refs/pr/22") +pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", revision="refs/pr/22") ``` or you can test it directly online with this [space](https://huggingface.co/spaces/diffusers/check_pr). diff --git a/examples/README.md b/examples/README.md index 4526d44e43d5..d09739768925 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,6 +42,8 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie | [**Text-to-Image fine-tuning**](./text_to_image) | βœ… | βœ… | | [**Textual Inversion**](./textual_inversion) | βœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | [**Dreambooth**](./dreambooth) | βœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) +| [**ControlNet**](./controlnet) | βœ… | βœ… | - +| [**InstructPix2Pix**](./instruct_pix2pix) | βœ… | βœ… | - | [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon. ## Community diff --git a/examples/community/README.md b/examples/community/README.md index ba0cc0344643..11da90764579 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -50,11 +50,11 @@ The following code requires roughly 12GB of GPU RAM. ```python from diffusers import DiffusionPipeline -from transformers import CLIPFeatureExtractor, CLIPModel +from transformers import CLIPImageProcessor, CLIPModel import torch -feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") +feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K") clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16) diff --git a/examples/community/bit_diffusion.py b/examples/community/bit_diffusion.py index c778b6cc6c71..18d5fca5619e 100644 --- a/examples/community/bit_diffusion.py +++ b/examples/community/bit_diffusion.py @@ -238,7 +238,7 @@ def __call__( **kwargs, ) -> Union[Tuple, ImagePipelineOutput]: latents = torch.randn( - (batch_size, self.unet.in_channels, height, width), + (batch_size, self.unet.config.in_channels, height, width), generator=generator, ) latents = decimal_to_bits(latents) * self.bit_scale diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py index 24f187b41c07..3e29ae50078b 100644 --- a/examples/community/checkpoint_merger.py +++ b/examples/community/checkpoint_merger.py @@ -199,24 +199,20 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] if not attr.startswith("_"): checkpoint_path_1 = os.path.join(cached_folders[1], attr) if os.path.exists(checkpoint_path_1): - files = list( - ( - *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")), - *glob.glob(os.path.join(checkpoint_path_1, "*.bin")), - ) - ) + files = [ + *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")), + *glob.glob(os.path.join(checkpoint_path_1, "*.bin")), + ] checkpoint_path_1 = files[0] if len(files) > 0 else None if len(cached_folders) < 3: checkpoint_path_2 = None else: checkpoint_path_2 = os.path.join(cached_folders[2], attr) if os.path.exists(checkpoint_path_2): - files = list( - ( - *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")), - *glob.glob(os.path.join(checkpoint_path_2, "*.bin")), - ) - ) + files = [ + *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")), + *glob.glob(os.path.join(checkpoint_path_2, "*.bin")), + ] checkpoint_path_2 = files[0] if len(files) > 0 else None # For an attr if both checkpoint_path_1 and 2 are None, ignore. # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match. diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 68bdf22f9454..3f4ab2ab9f4a 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -5,12 +5,13 @@ from torch import nn from torch.nn import functional as F from torchvision import transforms -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, DDIMScheduler, DiffusionPipeline, + DPMSolverMultistepScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel, @@ -63,8 +64,8 @@ def __init__( clip_model: CLIPModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], - feature_extractor: CLIPFeatureExtractor, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler], + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( @@ -125,17 +126,12 @@ def cond_fn( ): latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latents, timestep) # predict the noise residual noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t # compute predicted original sample from predicted noise also called @@ -258,7 +254,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/community/clip_guided_stable_diffusion_img2img.py b/examples/community/clip_guided_stable_diffusion_img2img.py index c9d2bc6e5931..a72a5a127c72 100644 --- a/examples/community/clip_guided_stable_diffusion_img2img.py +++ b/examples/community/clip_guided_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ AutoencoderKL, DDIMScheduler, DiffusionPipeline, + DPMSolverMultistepScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel, @@ -140,7 +141,7 @@ def __init__( clip_model: CLIPModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler], feature_extractor: CLIPFeatureExtractor, ): super().__init__() @@ -263,17 +264,12 @@ def cond_fn( ): latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latents, timestep) # predict the noise residual noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t # compute predicted original sample from predicted noise also called @@ -418,7 +414,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index eb9627106cbb..95292f5bdae8 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -17,11 +17,13 @@ import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -30,11 +32,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils import is_accelerate_available - -from ...utils import deprecate, logging -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from diffusers.utils import deprecate, is_accelerate_available, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -64,7 +62,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -84,7 +82,7 @@ def __init__( DPMSolverMultistepScheduler, ], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -513,7 +511,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index 3a514b4a6dd2..56bd381a9e65 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -15,7 +15,7 @@ # TODO: remove and import from diffusers.utils when the new version of diffusers is released from packaging import version from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -48,7 +48,7 @@ def preprocess(image): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -80,7 +80,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offsensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -92,7 +92,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( @@ -424,7 +424,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (1, self.unet.in_channels, height // 8, width // 8) + latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if self.device.type == "mps": # randn does not exist on mps diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index d3ef83c4f7f3..f50eb6cabc37 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -4,7 +4,7 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict @@ -79,7 +79,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -91,7 +91,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index f772620b5d28..8f33db71b9f3 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -5,7 +5,7 @@ import numpy as np import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict @@ -70,7 +70,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -82,7 +82,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() @@ -320,7 +320,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": @@ -416,7 +416,7 @@ def embed_text(self, text): def get_noise(self, seed, dtype=torch.float32, height=512, width=512): """Takes in random seed and returns corresponding noise vector""" return torch.randn( - (1, self.unet.in_channels, height // 8, width // 8), + (1, self.unet.config.in_channels, height // 8, width // 8), generator=torch.Generator(device=self.device).manual_seed(seed), device=self.device, dtype=dtype, diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index dedc31a0913a..e912ad5244be 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -6,7 +6,7 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline @@ -179,14 +179,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m return tokens, weights -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] if no_boseos_middle: weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) else: @@ -317,12 +317,14 @@ def get_weighted_text_embeddings( # pad the length of tokens and weights bos = pipe.tokenizer.bos_token_id eos = pipe.tokenizer.eos_token_id + pad = getattr(pipe.tokenizer, "pad_token_id", eos) prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights, max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -334,6 +336,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -376,7 +379,7 @@ def get_weighted_text_embeddings( def preprocess_image(image): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -387,7 +390,7 @@ def preprocess_image(image): def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) @@ -422,7 +425,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -436,7 +439,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: SchedulerMixin, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__( @@ -461,7 +464,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: SchedulerMixin, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__( vae=vae, @@ -624,7 +627,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, dev if image is None: shape = ( batch_size, - self.unet.in_channels, + self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index eb27e0cd9b7b..e756097cb7c3 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -6,7 +6,7 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTokenizer import diffusers from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin @@ -196,14 +196,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int): return tokens, weights -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] if no_boseos_middle: weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) else: @@ -342,12 +342,14 @@ def get_weighted_text_embeddings( # pad the length of tokens and weights bos = pipe.tokenizer.bos_token_id eos = pipe.tokenizer.eos_token_id + pad = getattr(pipe.tokenizer, "pad_token_id", eos) prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights, max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -359,6 +361,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -403,7 +406,7 @@ def get_weighted_text_embeddings( def preprocess_image(image): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -413,7 +416,7 @@ def preprocess_image(image): def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) @@ -441,7 +444,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: SchedulerMixin, safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__( @@ -468,7 +471,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: SchedulerMixin, safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__( vae_encoder=vae_encoder, @@ -483,7 +486,7 @@ def __init__( self.__init__additional__() def __init__additional__(self): - self.unet_in_channels = 4 + self.unet.config.in_channels = 4 self.vae_scale_factor = 8 def _encode_prompt( @@ -618,7 +621,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, gen if image is None: shape = ( batch_size, - self.unet_in_channels, + self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py index b1d69ec84576..4eb99cb96b42 100644 --- a/examples/community/magic_mix.py +++ b/examples/community/magic_mix.py @@ -93,7 +93,7 @@ def __call__( torch.manual_seed(seed) noise = torch.randn( - (1, self.unet.in_channels, height // 8, width // 8), + (1, self.unet.config.in_channels, height // 8, width // 8), ).to(self.device) latents = self.scheduler.add_noise( diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py index b49298113daf..ff6c7e68f783 100644 --- a/examples/community/multilingual_stable_diffusion.py +++ b/examples/community/multilingual_stable_diffusion.py @@ -3,7 +3,7 @@ import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, MBart50TokenizerFast, @@ -79,7 +79,7 @@ class MultilingualStableDiffusion(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -94,7 +94,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() @@ -355,7 +355,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/community/one_step_unet.py b/examples/community/one_step_unet.py index f3eaf1e0eb7a..7d34bfd83191 100755 --- a/examples/community/one_step_unet.py +++ b/examples/community/one_step_unet.py @@ -12,7 +12,7 @@ def __init__(self, unet, scheduler): def __call__(self): image = torch.randn( - (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), ) timestep = 1 diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py index c8fb309e4de3..b7fbc46b67cb 100755 --- a/examples/community/sd_text2img_k_diffusion.py +++ b/examples/community/sd_text2img_k_diffusion.py @@ -65,7 +65,7 @@ class StableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -105,7 +105,7 @@ def __init__( ) model = ModelWrapper(unet, scheduler.alphas_cumprod) - if scheduler.prediction_type == "v_prediction": + if scheduler.config.prediction_type == "v_prediction": self.k_diffusion_model = CompVisVDenoiser(model) else: self.k_diffusion_model = CompVisDenoiser(model) @@ -433,7 +433,7 @@ def __call__( sigmas = sigmas.to(text_embeddings.dtype) # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py index 92863ae65412..5891b9fb11a8 100644 --- a/examples/community/seed_resize_stable_diffusion.py +++ b/examples/community/seed_resize_stable_diffusion.py @@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -42,7 +42,7 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -54,7 +54,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() self.register_modules( @@ -262,8 +262,8 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) - latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) + latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py index 0ba4d6cb726b..55d805bc8c32 100644 --- a/examples/community/speech_to_image_diffusion.py +++ b/examples/community/speech_to_image_diffusion.py @@ -3,7 +3,7 @@ import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, WhisperForConditionalGeneration, @@ -37,7 +37,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() @@ -190,7 +190,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/community/stable_diffusion_comparison.py b/examples/community/stable_diffusion_comparison.py index 8b2980442390..7997a0cc0186 100644 --- a/examples/community/stable_diffusion_comparison.py +++ b/examples/community/stable_diffusion_comparison.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -46,7 +46,7 @@ class StableDiffusionComparisonPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionMegaSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -58,7 +58,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super()._init_() diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index ec23564ae3cb..a8a51b5489a3 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -1,15 +1,16 @@ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, @@ -86,7 +87,14 @@ def prepare_image(image): def prepare_controlnet_conditioning_image( - controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype + controlnet_conditioning_image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance, ): if not isinstance(controlnet_conditioning_image, torch.Tensor): if isinstance(controlnet_conditioning_image, PIL.Image.Image): @@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image( controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) + return controlnet_conditioning_image @@ -132,10 +143,10 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -156,6 +167,9 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -276,8 +290,7 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -425,6 +438,42 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + else: + raise ValueError("controlnet condition image is not valid") + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + else: + raise ValueError("prompt or prompt_embeds are not valid") + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, prompt, @@ -439,6 +488,7 @@ def check_inputs( strength=None, controlnet_guidance_start=None, controlnet_guidance_end=None, + controlnet_conditioning_scale=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -477,58 +527,51 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) - controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) - controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], PIL.Image.Image - ) - controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], torch.Tensor - ) + # check controlnet condition image - if ( - not controlnet_cond_image_is_pil - and not controlnet_cond_image_is_tensor - and not controlnet_cond_image_is_pil_list - and not controlnet_cond_image_is_tensor_list - ): - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) + if isinstance(self.controlnet, ControlNetModel): + self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, MultiControlNetModel): + if not isinstance(controlnet_conditioning_image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") - if controlnet_cond_image_is_pil: - controlnet_cond_image_batch_size = 1 - elif controlnet_cond_image_is_tensor: - controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] - elif controlnet_cond_image_is_pil_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - elif controlnet_cond_image_is_tensor_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) + if len(controlnet_conditioning_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] + for image_ in controlnet_conditioning_image: + self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds) + else: + assert False - if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" - ) + # Check `controlnet_conditioning_scale` + + if isinstance(self.controlnet, ControlNetModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(self.controlnet, MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False if isinstance(image, torch.Tensor): if image.ndim != 3 and image.ndim != 4: raise ValueError("`image` must have 3 or 4 dimensions") - # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4: - # raise ValueError("`mask_image` must have 2, 3, or 4 dimensions") - if image.ndim == 3: image_batch_size = 1 image_channels, image_height, image_width = image.shape elif image.ndim == 4: image_batch_size, image_channels, image_height, image_width = image.shape + else: + assert False if image_channels != 3: raise ValueError("`image` must have 3 channels") @@ -660,7 +703,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, controlnet_guidance_start: float = 0.0, controlnet_guidance_end: float = 1.0, ): @@ -699,8 +742,7 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -761,7 +803,6 @@ def __call__( self.check_inputs( prompt, image, - # mask_image, controlnet_conditioning_image, height, width, @@ -772,6 +813,7 @@ def __call__( strength, controlnet_guidance_start, controlnet_guidance_end, + controlnet_conditioning_scale, ) # 2. Define call parameters @@ -788,6 +830,9 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -799,22 +844,41 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Prepare mask, image, and controlnet_conditioning_image + # 4. Prepare image, and controlnet_conditioning_image image = prepare_image(image) - # mask_image = prepare_mask_image(mask_image) + # condition image(s) + if isinstance(self.controlnet, ControlNetModel): + controlnet_conditioning_image = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=controlnet_conditioning_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + elif isinstance(self.controlnet, MultiControlNetModel): + controlnet_conditioning_images = [] + + for image_ in controlnet_conditioning_image: + image_ = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) - controlnet_conditioning_image = prepare_controlnet_conditioning_image( - controlnet_conditioning_image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) + controlnet_conditioning_images.append(image_) - # masked_image = image * (mask_image < 0.5) + controlnet_conditioning_image = controlnet_conditioning_images + else: + assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -832,9 +896,6 @@ def __call__( generator, ) - if do_classifier_free_guidance: - controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -864,15 +925,10 @@ def __call__( t, encoder_hidden_states=prompt_embeds, controlnet_cond=controlnet_conditioning_image, + conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - # predict the noise residual noise_pred = self.unet( latent_model_input, diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index b7c8a2a7a7f0..c47f4c3194e8 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -7,7 +7,7 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker @@ -233,7 +233,7 @@ def __init__( controlnet: ControlNetModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -373,8 +373,7 @@ def _encode_prompt( do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -833,8 +832,7 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index f435a3274f45..bad1df0e13fb 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -7,7 +7,7 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker @@ -233,7 +233,7 @@ def __init__( controlnet: ControlNetModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -373,8 +373,7 @@ def _encode_prompt( do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -876,8 +875,7 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index 1c4af893cd2f..0fec5557a637 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -2,7 +2,7 @@ import PIL.Image import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -47,7 +47,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionMegaSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -60,7 +60,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/examples/community/stable_unclip.py b/examples/community/stable_unclip.py index 8ff9c44d19fd..1b438c8fcb3e 100644 --- a/examples/community/stable_unclip.py +++ b/examples/community/stable_unclip.py @@ -46,7 +46,7 @@ def __init__( ): super().__init__() - decoder_pipe_kwargs = dict(image_encoder=None) if decoder_pipe_kwargs is None else decoder_pipe_kwargs + decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py index be2d6f4d3d5b..99a488788a0d 100644 --- a/examples/community/text_inpainting.py +++ b/examples/community/text_inpainting.py @@ -3,7 +3,7 @@ import PIL import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor, CLIPTextModel, @@ -52,7 +52,7 @@ class TextInpainting(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -66,7 +66,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() diff --git a/examples/community/unclip_image_interpolation.py b/examples/community/unclip_image_interpolation.py index fc313acd07bd..d0b54125b688 100644 --- a/examples/community/unclip_image_interpolation.py +++ b/examples/community/unclip_image_interpolation.py @@ -5,7 +5,7 @@ import torch from torch.nn import functional as F from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, @@ -50,7 +50,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `image_encoder`. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of @@ -75,7 +75,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline): text_proj: UnCLIPTextProjModel text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection super_res_first: UNet2DModel super_res_last: UNet2DModel @@ -90,7 +90,7 @@ def __init__( text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_proj: UnCLIPTextProjModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, super_res_first: UNet2DModel, super_res_last: UNet2DModel, @@ -270,7 +270,7 @@ def __call__( The images to use for the image interpolation. Only accepts a list of two PIL Images or If you provide a tensor, it needs to comply with the configuration of [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed. + `CLIPImageProcessor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed. steps (`int`, *optional*, defaults to 5): The number of interpolation images to generate. decoder_num_inference_steps (`int`, *optional*, defaults to 25): diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py index da2948cea6cb..aec79fb8e12e 100644 --- a/examples/community/wildcard_stable_diffusion.py +++ b/examples/community/wildcard_stable_diffusion.py @@ -6,7 +6,7 @@ from typing import Callable, Dict, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict @@ -104,7 +104,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -116,7 +116,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): super().__init__() @@ -337,7 +337,7 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: if self.device.type == "mps": diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 32de31e14bbd..4b388d92a195 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -267,3 +267,149 @@ image = pipe( image.save("./output.png") ``` + +## Training with Flax/JAX + +For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. + +### Running on Google Cloud TPU + +See below for commands to set up a TPU VM(`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax). + +First create a single TPUv4-8 VM and connect to it: + +``` +ZONE=us-central2-b +TPU_TYPE=v4-8 +VM_NAME=hg_flax + +gcloud alpha compute tpus tpu-vm create $VM_NAME \ + --zone $ZONE \ + --accelerator-type $TPU_TYPE \ + --version tpu-vm-v4-base + +gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \ +``` + +When connected install JAX `0.4.5`: + +``` +pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +To verify that JAX was correctly installed, you can run the following command: + +``` +import jax +jax.device_count() +``` + +This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM. + +Then install Diffusers and the library's training dependencies: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run + +```bash +pip install -U -r requirements_flax.txt +``` + +If you want to use Weights and Biases logging, you should also install `wandb` now + +```bash +pip install wandb +``` + +Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress + +``` +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png +``` + +We encourage you to store or share your model with the community. To use huggingface hub, please login to your Hugging Face account, or ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already): + +``` +huggingface-cli login +``` + +Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub: + +```bash +export MODEL_DIR="runwayml/stable-diffusion-v1-5" +export OUTPUT_DIR="control_out" +export HUB_MODEL_ID="fill-circle-controlnet" +``` + +And finally start the training + +```bash +python3 train_controlnet_flax.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=fusing/fill50k \ + --resolution=512 \ + --learning_rate=1e-5 \ + --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --validation_steps=1000 \ + --train_batch_size=2 \ + --revision="non-ema" \ + --from_pt \ + --report_to="wandb" \ + --max_train_steps=10000 \ + --push_to_hub \ + --hub_model_id=$HUB_MODEL_ID + ``` + +Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet). + +Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command: + +```bash +python3 train_controlnet_flax.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=multimodalart/facesyntheticsspigacaptioned \ + --streaming \ + --conditioning_image_column=spiga_seg \ + --image_column=image \ + --caption_column=image_caption \ + --resolution=512 \ + --max_train_samples 50 \ + --max_train_steps 5 \ + --learning_rate=1e-5 \ + --validation_steps=2 \ + --train_batch_size=1 \ + --revision="flax" \ + --report_to="wandb" +``` + +Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options: + +* [Webdataset](https://webdataset.github.io/webdataset/) +* [TorchData](https://github.com/pytorch/data) +* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) + +When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing: + +```bash + --checkpointing_steps=500 +``` +This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500 + +You can then start your training from this saved checkpoint with + +```bash + --controlnet_model_name_or_path="./control_out/500" +``` + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`. + +We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation). \ No newline at end of file diff --git a/examples/controlnet/requirements.txt b/examples/controlnet/requirements.txt index 5deb15969f09..d19c62296702 100644 --- a/examples/controlnet/requirements.txt +++ b/examples/controlnet/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/controlnet/requirements_flax.txt b/examples/controlnet/requirements_flax.txt new file mode 100644 index 000000000000..b6eb64e25462 --- /dev/null +++ b/examples/controlnet/requirements_flax.txt @@ -0,0 +1,9 @@ +transformers>=4.25.1 +datasets +flax +optax +torch +torchvision +ftfy +tensorboard +Jinja2 diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 6c14e8ca10db..30e43075d809 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -19,7 +19,6 @@ import os import random from pathlib import Path -from typing import Optional import accelerate import numpy as np @@ -31,7 +30,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torchvision import transforms @@ -56,7 +55,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__) @@ -106,7 +105,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler image_logs = [] for validation_prompt, validation_image in zip(validation_prompts, validation_images): - validation_image = Image.open(validation_image) + validation_image = Image.open(validation_image).convert("RGB") images = [] @@ -543,16 +542,13 @@ def make_train_dataset(args, tokenizer, accelerator): cache_dir=args.cache_dir, ) else: - data_files = {} if args.train_data_dir is not None: - data_files["train"] = os.path.join(args.train_data_dir, "**") - dataset = load_dataset( - "imagefolder", - data_files=data_files, - cache_dir=args.cache_dir, - ) + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) # See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script # Preprocessing the datasets. # We need to tokenize inputs and targets. @@ -581,7 +577,7 @@ def make_train_dataset(args, tokenizer, accelerator): if args.conditioning_image_column is None: conditioning_image_column = column_names[2] - logger.info(f"conditioning image column defaulting to {caption_column}") + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") else: conditioning_image_column = args.conditioning_image_column if conditioning_image_column not in column_names: @@ -661,16 +657,6 @@ def collate_fn(examples): } -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -704,22 +690,14 @@ def main(args): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) @@ -994,8 +972,10 @@ def load_model_hook(models, input_dir): noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type @@ -1053,7 +1033,12 @@ def load_model_hook(models, input_dir): controlnet.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py new file mode 100644 index 000000000000..f5ea3ce84bf3 --- /dev/null +++ b/examples/controlnet/train_controlnet_flax.py @@ -0,0 +1,1089 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import torch +import torch.utils.checkpoint +import transformers +from datasets import load_dataset, load_from_disk +from flax import jax_utils +from flax.core.frozen_dict import unfreeze +from flax.training import train_state +from flax.training.common_utils import shard +from huggingface_hub import create_repo, upload_folder +from PIL import Image, PngImagePlugin +from torch.utils.data import IterableDataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed + +from diffusers import ( + FlaxAutoencoderKL, + FlaxControlNetModel, + FlaxDDPMScheduler, + FlaxStableDiffusionControlNetPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.utils import check_min_version, is_wandb_available + + +# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image +# see more https://github.com/python-pillow/Pillow/issues/5610 +LARGE_ENOUGH_NUMBER = 100 +PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.15.0") + +logger = logging.getLogger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): + logger.info("Running validation... ") + + pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=tokenizer, + controlnet=controlnet, + safety_checker=None, + dtype=weight_dtype, + revision=args.revision, + from_pt=args.from_pt, + ) + params = jax_utils.replicate(params) + params["controlnet"] = controlnet_params + + num_samples = jax.device_count() + prng_seed = jax.random.split(rng, jax.device_count()) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + prompts = num_samples * [validation_prompt] + prompt_ids = pipeline.prepare_text_inputs(prompts) + prompt_ids = shard(prompt_ids) + + validation_image = Image.open(validation_image).convert("RGB") + processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image]) + processed_image = shard(processed_image) + images = pipeline( + prompt_ids=prompt_ids, + image=processed_image, + params=params, + prng_seed=prng_seed, + num_inference_steps=50, + jit=True, + ).images + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + images = pipeline.numpy_to_pil(images) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + if args.report_to == "wandb": + formatted_images = [] + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + wandb.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {args.report_to}") + + return image_logs + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet- {repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--from_pt", + action="store_true", + help="Load the pretrained model from a PyTorch checkpoint.", + ) + parser.add_argument( + "--controlnet_revision", + type=str, + default=None, + help="Revision of controlnet model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_from_pt", + action="store_true", + help="Load the controlnet model from a PyTorch checkpoint.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=5000, + help=("Save a checkpoint of the training state every X updates."), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help=("log training metric every X steps to `--report_t`"), + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.") + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder." + "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ." + "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--load_from_disk", + action="store_true", + help=( + "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`" + "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk" + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set. Needed if `streaming` is set to True." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` and logging the images." + ), + ) + parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams).")) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet_flax", + help=("The `project` argument passed to wandb"), + ) + parser.add_argument( + "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over" + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + # This idea comes from + # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370 + if args.streaming and args.max_train_samples is None: + raise ValueError("You must specify `max_train_samples` when using dataset streaming.") + + return args + + +def make_train_dataset(args, tokenizer, batch_size=None): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + streaming=args.streaming, + ) + else: + if args.train_data_dir is not None: + if args.load_from_disk: + dataset = load_from_disk( + args.train_data_dir, + ) + else: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if isinstance(dataset["train"], IterableDataset): + column_names = next(iter(dataset["train"])).keys() + else: + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {caption_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if random.random() < args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["input_ids"] = tokenize_captions(examples) + + return examples + + if jax.process_index() == 0: + if args.max_train_samples is not None: + if args.streaming: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples) + else: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + if args.streaming: + train_dataset = dataset["train"].map( + preprocess_train, + batched=True, + batch_size=batch_size, + remove_columns=list(dataset["train"].features.keys()), + ) + else: + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + batch = { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } + batch = {k: v.numpy() for k, v in batch.items()} + return batch + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # wandb init + if jax.process_index() == 0 and args.report_to == "wandb": + wandb.init( + entity=args.wandb_entity, + project=args.tracker_project_name, + job_type="train", + config=args, + ) + + if args.seed is not None: + set_seed(args.seed) + + rng = jax.random.PRNGKey(0) + + # Handle the repository creation + if jax.process_index() == 0: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + else: + raise NotImplementedError("No tokenizer specified!") + + # Get the datasets: you can either provide your own training and evaluation files (see below) + total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps + train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=not args.streaming, + collate_fn=collate_fn, + batch_size=total_train_batch_size, + num_workers=args.dataloader_num_workers, + drop_last=True, + ) + + weight_dtype = jnp.float32 + if args.mixed_precision == "fp16": + weight_dtype = jnp.float16 + elif args.mixed_precision == "bf16": + weight_dtype = jnp.bfloat16 + + # Load models and create wrapper for stable diffusion + text_encoder = FlaxCLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + dtype=weight_dtype, + revision=args.revision, + from_pt=args.from_pt, + ) + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + subfolder="vae", + dtype=weight_dtype, + from_pt=args.from_pt, + ) + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + dtype=weight_dtype, + revision=args.revision, + from_pt=args.from_pt, + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + args.controlnet_model_name_or_path, + revision=args.controlnet_revision, + from_pt=args.controlnet_from_pt, + dtype=jnp.float32, + ) + else: + logger.info("Initializing controlnet weights from unet") + rng, rng_params = jax.random.split(rng) + + controlnet = FlaxControlNetModel( + in_channels=unet.config.in_channels, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + attention_head_dim=unet.config.attention_head_dim, + cross_attention_dim=unet.config.cross_attention_dim, + use_linear_projection=unet.config.use_linear_projection, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + ) + controlnet_params = controlnet.init_weights(rng=rng_params) + controlnet_params = unfreeze(controlnet_params) + for key in [ + "conv_in", + "time_embedding", + "down_blocks_0", + "down_blocks_1", + "down_blocks_2", + "down_blocks_3", + "mid_block", + ]: + controlnet_params[key] = unet_params[key] + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + adamw = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + optimizer = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + adamw, + ) + + state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer) + + noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + + # Initialize our training + validation_rng, train_rngs = jax.random.split(rng) + train_rngs = jax.random.split(train_rngs, jax.local_device_count()) + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler_state.common.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + alpha = sqrt_alphas_cumprod[timesteps] + sigma = sqrt_one_minus_alphas_cumprod[timesteps] + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng): + # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1 + if args.gradient_accumulation_steps > 1: + grad_steps = args.gradient_accumulation_steps + batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch) + + def compute_loss(params, minibatch, sample_rng): + # Convert images to latent space + vae_outputs = vae.apply( + {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + # Sample a random timestep for each image + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder( + minibatch["input_ids"], + params=text_encoder_params, + train=False, + )[0] + + controlnet_cond = minibatch["conditioning_pixel_values"] + + # Predict the noise residual and compute loss + down_block_res_samples, mid_block_res_sample = controlnet.apply( + {"params": params}, + noisy_latents, + timesteps, + encoder_hidden_states, + controlnet_cond, + train=True, + return_dict=False, + ) + + model_pred = unet.apply( + {"params": unet_params}, + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = (target - model_pred) ** 2 + + if args.snr_gamma is not None: + snr = jnp.array(compute_snr(timesteps)) + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + loss = loss * snr_loss_weights + + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + + # get a minibatch (one gradient accumulation slice) + def get_minibatch(batch, grad_idx): + return jax.tree_util.tree_map( + lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), + batch, + ) + + def loss_and_grad(grad_idx, train_rng): + # create minibatch for the grad step + minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch + sample_rng, train_rng = jax.random.split(train_rng, 2) + loss, grad = grad_fn(state.params, minibatch, sample_rng) + return loss, grad, train_rng + + if args.gradient_accumulation_steps == 1: + loss, grad, new_train_rng = loss_and_grad(None, train_rng) + else: + init_loss_grad_rng = ( + 0.0, # initial value for cumul_loss + jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad + train_rng, # initial value for train_rng + ) + + def cumul_grad_step(grad_idx, loss_grad_rng): + cumul_loss, cumul_grad, train_rng = loss_grad_rng + loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng) + cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad)) + return cumul_loss, cumul_grad, new_train_rng + + loss, grad, new_train_rng = jax.lax.fori_loop( + 0, + args.gradient_accumulation_steps, + cumul_grad_step, + init_loss_grad_rng, + ) + loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad)) + + grad = jax.lax.pmean(grad, "batch") + + new_state = state.apply_gradients(grads=grad) + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_state, metrics, new_train_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + unet_params = jax_utils.replicate(unet_params) + text_encoder_params = jax_utils.replicate(text_encoder.params) + vae_params = jax_utils.replicate(vae_params) + + # Train! + if args.streaming: + dataset_length = args.max_train_samples + else: + dataset_length = len(train_dataloader) + num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps) + + # Scheduler and math around the number of training steps. + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}") + + if jax.process_index() == 0: + wandb.define_metric("*", step_metric="train/step") + wandb.config.update( + { + "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset), + "total_train_batch_size": total_train_batch_size, + "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, + "num_devices": jax.device_count(), + } + ) + + global_step = 0 + epochs = tqdm( + range(args.num_train_epochs), + desc="Epoch ... ", + position=0, + disable=jax.process_index() > 0, + ) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = ( + args.max_train_samples // total_train_batch_size + if args.streaming + else len(train_dataset) // total_train_batch_size + ) + train_step_progress_bar = tqdm( + total=steps_per_epoch, + desc="Training...", + position=1, + leave=False, + disable=jax.process_index() > 0, + ) + # train + for batch in train_dataloader: + batch = shard(batch) + state, train_metric, train_rngs = p_train_step( + state, unet_params, text_encoder_params, vae_params, batch, train_rngs + ) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + global_step += 1 + if global_step >= args.max_train_steps: + break + + if ( + args.validation_prompt is not None + and global_step % args.validation_steps == 0 + and jax.process_index() == 0 + ): + _ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + + if global_step % args.logging_steps == 0 and jax.process_index() == 0: + if args.report_to == "wandb": + wandb.log( + { + "train/step": global_step, + "train/epoch": epoch, + "train/loss": jax_utils.unreplicate(train_metric)["loss"], + } + ) + if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: + controlnet.save_pretrained( + f"{args.output_dir}/{global_step}", + params=get_params_to_save(state.params), + ) + + train_metric = jax_utils.unreplicate(train_metric) + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + + # Create the pipeline using using the trained modules and save it. + if jax.process_index() == 0: + if args.validation_prompt is not None: + image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + else: + image_logs = None + + controlnet.save_pretrained( + args.output_dir, + params=get_params_to_save(state.params), + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + +if __name__ == "__main__": + main() diff --git a/examples/dreambooth/requirements.txt b/examples/dreambooth/requirements.txt index 7d93f3d03bd8..7a612982f4ab 100644 --- a/examples/dreambooth/requirements.txt +++ b/examples/dreambooth/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 414ecdeb1fb7..141aafb85128 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -21,7 +21,6 @@ import os import warnings from pathlib import Path -from typing import Optional import accelerate import numpy as np @@ -32,7 +31,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torch.utils.data import Dataset @@ -57,7 +56,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__) @@ -417,6 +416,16 @@ def parse_args(input_args=None): ), ) + parser.add_argument( + "--offset_noise", + action="store_true", + default=False, + help=( + "Fine-tuning against a modified noise" + " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." + ), + ) + if input_args is not None: args = parser.parse_args(input_args) else: @@ -565,16 +574,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -667,22 +666,14 @@ def main(args): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) @@ -943,7 +934,12 @@ def load_model_hook(models, input_dir): latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) + if args.offset_noise: + noise = torch.randn_like(latents) + 0.1 * torch.randn( + latents.shape[0], latents.shape[1], 1, 1, device=latents.device + ) + else: + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -1028,7 +1024,12 @@ def load_model_hook(models, input_dir): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 46edd5399e88..8c2faa7ec877 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -22,7 +22,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed from diffusers import ( FlaxAutoencoderKL, @@ -36,7 +36,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) @@ -652,7 +652,7 @@ def checkpoint(step=None): tokenizer=tokenizer, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), ) outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index daef268ff8f3..a117bd394895 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -20,7 +20,6 @@ import os import warnings from pathlib import Path -from typing import Optional import numpy as np import torch @@ -30,7 +29,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torch.utils.data import Dataset @@ -54,12 +53,12 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__) -def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None): +def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -80,7 +79,7 @@ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_fol --- """ model_card = f""" -# LoRA DreamBooth - {repo_name} +# LoRA DreamBooth - {repo_id} These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n {img_str} @@ -528,16 +527,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -625,23 +614,14 @@ def main(args): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) @@ -1027,13 +1007,18 @@ def main(args): if args.push_to_hub: save_model_card( - repo_name, + repo_id, images=images, base_model=args.pretrained_model_name_or_path, prompt=args.instance_prompt, repo_folder=args.output_dir, ) - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/instruct_pix2pix/requirements.txt b/examples/instruct_pix2pix/requirements.txt index 176ef92a1424..e18cc9e4215e 100644 --- a/examples/instruct_pix2pix/requirements.txt +++ b/examples/instruct_pix2pix/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 datasets diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 57430b7f150a..b542d01c112a 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -21,7 +21,6 @@ import math import os from pathlib import Path -from typing import Optional import accelerate import datasets @@ -37,7 +36,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -52,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__, log_level="INFO") @@ -363,16 +362,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def convert_to_np(image, resolution): image = image.convert("RGB").resize((resolution, resolution)) return np.array(image).transpose(2, 0, 1) @@ -436,22 +425,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( @@ -470,19 +451,18 @@ def main(): # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. - if accelerator.is_main_process: - logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") - in_channels = 8 - out_channels = unet.conv_in.out_channels - unet.register_to_config(in_channels=in_channels) - - with torch.no_grad(): - new_conv_in = nn.Conv2d( - in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding - ) - new_conv_in.weight.zero_() - new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) - unet.conv_in = new_conv_in + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in # Freeze vae and text_encoder vae.requires_grad_(False) @@ -673,7 +653,7 @@ def preprocess_train(examples): examples["edited_pixel_values"] = edited_images # Preprocess the captions. - captions = [caption for caption in examples[edit_prompt_column]] + captions = list(examples[edit_prompt_column]) examples["input_ids"] = tokenize_captions(captions) return examples @@ -813,7 +793,7 @@ def collate_fn(examples): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -911,9 +891,12 @@ def collate_fn(examples): # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) + # The models need unwrapping because for compatibility in distributed training mode. pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unet, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), revision=args.revision, torch_dtype=weight_dtype, ) @@ -923,7 +906,9 @@ def collate_fn(examples): # run inference original_image = download_image(args.val_image_url) edited_images = [] - with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + with torch.autocast( + str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" + ): for _ in range(args.num_validation_images): edited_images.append( pipeline( @@ -968,12 +953,17 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if args.validation_prompt is not None: edited_images = [] pipeline = pipeline.to(accelerator.device) - with torch.autocast(str(accelerator.device)): + with torch.autocast(str(accelerator.device).replace(":0", "")): for _ in range(args.num_validation_images): edited_images.append( pipeline( diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py index 6136f7233900..3d4466bf94b7 100644 --- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py +++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py @@ -3,7 +3,6 @@ import math import os from pathlib import Path -from typing import Optional import colossalai import torch @@ -16,7 +15,7 @@ from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from PIL import Image from torch.utils.data import Dataset from torchvision import transforms @@ -344,16 +343,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - # Gemini + ZeRO DDP def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): from colossalai.nn.parallel import GeminiDDP @@ -413,22 +402,14 @@ def main(args): # Handle the repository creation if local_rank == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0]) @@ -679,7 +660,12 @@ def collate_fn(examples): logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": diff --git a/examples/research_projects/dreambooth_inpaint/requirements.txt b/examples/research_projects/dreambooth_inpaint/requirements.txt index f17dfab9653b..aad6387026f1 100644 --- a/examples/research_projects/dreambooth_inpaint/requirements.txt +++ b/examples/research_projects/dreambooth_inpaint/requirements.txt @@ -1,5 +1,5 @@ diffusers==0.9.0 -accelerate +accelerate>=0.16.0 torchvision transformers>=4.21.0 ftfy diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py index 247361d21299..5158f9fc3bc0 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py @@ -5,7 +5,6 @@ import os import random from pathlib import Path -from typing import Optional import numpy as np import torch @@ -14,7 +13,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from PIL import Image, ImageDraw from torch.utils.data import Dataset from torchvision import transforms @@ -402,28 +401,18 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir, - accelerator_project_config=accelerator_project_config, + project_config=project_config, ) # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate @@ -485,22 +474,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -816,7 +797,12 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index e415e6965317..07df6f201175 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -4,7 +4,6 @@ import os import random from pathlib import Path -from typing import Optional import numpy as np import torch @@ -13,7 +12,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from PIL import Image, ImageDraw from torch.utils.data import Dataset from torchvision import transforms @@ -401,16 +400,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) @@ -422,7 +411,7 @@ def main(): mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir, - accelerator_project_config=accelerator_project_config, + project_config=accelerator_project_config, ) # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate @@ -484,22 +473,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -835,7 +816,12 @@ def collate_fn(examples): unet.save_attn_procs(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/intel_opts/README.md b/examples/research_projects/intel_opts/README.md index fc606df7d170..6b25679efbe9 100644 --- a/examples/research_projects/intel_opts/README.md +++ b/examples/research_projects/intel_opts/README.md @@ -11,6 +11,26 @@ We accelereate the fine-tuning for textual inversion with Intel Extension for Py ## Accelerating the inference for Stable Diffusion using Bfloat16 We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support. +```bash +pip install diffusers transformers accelerate scipy safetensors + +export KMP_BLOCKTIME=1 +export KMP_SETTINGS=1 +export KMP_AFFINITY=granularity=fine,compact,1,0 + +# Intel OpenMP +export OMP_NUM_THREADS=< Cores to use > +export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so +# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support. +export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so +export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000" + +# Launch with default DDIM +numactl --membind -C python python inference_bf16.py +# Launch with DPMSolverMultistepScheduler +numactl --membind -C python python inference_bf16.py --dpm + +``` ## Accelerating the inference for Stable Diffusion using INT8 diff --git a/examples/research_projects/intel_opts/inference_bf16.py b/examples/research_projects/intel_opts/inference_bf16.py index 8431693a45c8..96ec709f433c 100644 --- a/examples/research_projects/intel_opts/inference_bf16.py +++ b/examples/research_projects/intel_opts/inference_bf16.py @@ -1,49 +1,56 @@ +import argparse + import intel_extension_for_pytorch as ipex import torch -from PIL import Image - -from diffusers import StableDiffusionPipeline - -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols +from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - grid_w, grid_h = grid.size - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid +parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False) +parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not") +parser.add_argument("--steps", default=None, type=int, help="Num inference steps") +args = parser.parse_args() -prompt = ["a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings"] -batch_size = 8 -prompt = prompt * batch_size - device = "cpu" +prompt = "a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings" + model_id = "path-to-your-trained-model" -model = StableDiffusionPipeline.from_pretrained(model_id) -model = model.to(device) +pipe = StableDiffusionPipeline.from_pretrained(model_id) +if args.dpm: + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to(device) # to channels last -model.unet = model.unet.to(memory_format=torch.channels_last) -model.vae = model.vae.to(memory_format=torch.channels_last) -model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last) -model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last) +pipe.unet = pipe.unet.to(memory_format=torch.channels_last) +pipe.vae = pipe.vae.to(memory_format=torch.channels_last) +pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last) +if pipe.requires_safety_checker: + pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last) # optimize with ipex -model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True) -model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True) -model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) -model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) +sample = torch.randn(2, 4, 64, 64) +timestep = torch.rand(1) * 999 +encoder_hidden_status = torch.randn(2, 77, 768) +input_example = (sample, timestep, encoder_hidden_status) +try: + pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example) +except Exception: + pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True) +pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True) +pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) +if pipe.requires_safety_checker: + pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) # compute seed = 666 generator = torch.Generator(device).manual_seed(seed) +generate_kwargs = {"generator": generator} +if args.steps is not None: + generate_kwargs["num_inference_steps"] = args.steps + with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): - images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images + image = pipe(prompt, **generate_kwargs).images[0] - # save image - grid = image_grid(images, rows=2, cols=4) - grid.save(model_id + ".png") +# save image +image.save("generated.png") diff --git a/examples/research_projects/intel_opts/textual_inversion/requirements.txt b/examples/research_projects/intel_opts/textual_inversion/requirements.txt index 17b32ea8a271..af7ed6b21f6f 100644 --- a/examples/research_projects/intel_opts/textual_inversion/requirements.txt +++ b/examples/research_projects/intel_opts/textual_inversion/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.21.0 ftfy diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index f446efc0b0c0..1580cb392e8d 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -4,7 +4,6 @@ import os import random from pathlib import Path -from typing import Optional import intel_extension_for_pytorch as ipex import numpy as np @@ -15,7 +14,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder # TODO: remove and import from diffusers.utils when the new version of diffusers is released from packaging import version @@ -23,7 +22,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler @@ -356,16 +355,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def freeze_params(params): for param in params: param.requires_grad = False @@ -388,22 +377,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -632,7 +613,7 @@ def main(): tokenizer=tokenizer, scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), ) pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings @@ -640,7 +621,12 @@ def main(): save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/lora/requirements.txt b/examples/research_projects/lora/requirements.txt index 13b6feeec964..89a1b73e7072 100644 --- a/examples/research_projects/lora/requirements.txt +++ b/examples/research_projects/lora/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 datasets diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index a53af7bcffd2..fd516fff9811 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -22,7 +22,6 @@ import os import random from pathlib import Path -from typing import Optional import datasets import numpy as np @@ -34,7 +33,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -55,7 +54,7 @@ logger = get_logger(__name__, log_level="INFO") -def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None): +def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -75,7 +74,7 @@ def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, re --- """ model_card = f""" -# LoRA text2image fine-tuning - {repo_name} +# LoRA text2image fine-tuning - {repo_id} These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n {img_str} """ @@ -386,16 +385,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - DATASET_NAME_MAPPING = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -441,22 +430,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo_name = create_repo(repo_name, exist_ok=True) - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( @@ -542,9 +523,9 @@ def main(): lora_layers = AttnProcsLayers(unet.attn_processors) # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -582,7 +563,7 @@ def main(): else: optimizer_cls = torch.optim.AdamW - if args.peft: + if args.use_peft: # Optimizer creation params_to_optimize = ( itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -724,7 +705,7 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - if args.peft: + if args.use_peft: if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -813,7 +794,7 @@ def collate_fn(examples): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -842,7 +823,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - if args.peft: + if args.use_peft: params_to_clip = ( itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder @@ -922,18 +903,22 @@ def collate_fn(examples): if accelerator.is_main_process: if args.use_peft: lora_config = {} - state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) - lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) + unwarpped_unet = accelerator.unwrap_model(unet) + state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet)) + lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True) if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) text_encoder_state_dict = get_peft_model_state_dict( - text_encoder, state_dict=accelerator.get_state_dict(text_encoder) + unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder) ) text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) + lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict( + inference=True + ) - accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt")) - with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f: + accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt")) + with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f: json.dump(lora_config, f) else: unet = unet.to(torch.float32) @@ -941,13 +926,18 @@ def collate_fn(examples): if args.push_to_hub: save_model_card( - repo_name, + repo_id, images=images, base_model=args.pretrained_model_name_or_path, dataset_name=args.dataset_name, repo_folder=args.output_dir, ) - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) # Final inference # Load previous pipeline @@ -957,12 +947,12 @@ def collate_fn(examples): if args.use_peft: - def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype): - with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f: + def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): + with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f: lora_config = json.load(f) print(lora_config) - checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt" + checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt") lora_checkpoint_sd = torch.load(checkpoint) unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} text_encoder_lora_ds = { @@ -985,9 +975,7 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype): pipe.to(device) return pipe - pipeline = load_and_set_lora_ckpt( - pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype - ) + pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype) else: pipeline = pipeline.to(accelerator.device) @@ -995,7 +983,10 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype): pipeline.unet.load_attn_procs(args.output_dir) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None images = [] for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) diff --git a/examples/research_projects/mulit_token_textual_inversion/requirements.txt b/examples/research_projects/mulit_token_textual_inversion/requirements.txt index 7d93f3d03bd8..7a612982f4ab 100644 --- a/examples/research_projects/mulit_token_textual_inversion/requirements.txt +++ b/examples/research_projects/mulit_token_textual_inversion/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py index 05f714715fc9..622c51d2e52e 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py @@ -19,7 +19,6 @@ import os import random from pathlib import Path -from typing import Optional import numpy as np import PIL @@ -30,7 +29,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from multi_token_clip import MultiTokenCLIPTokenizer # TODO: remove and import from diffusers.utils when the new version of diffusers is released @@ -547,16 +546,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) @@ -596,22 +585,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load tokenizer if args.tokenizer_name: tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -932,7 +913,12 @@ def main(): save_progress(tokenizer, text_encoder, accelerator, save_path) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py index c23fa4f5d38a..ecc89f98298e 100644 --- a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py +++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py @@ -4,7 +4,6 @@ import os import random from pathlib import Path -from typing import Optional import jax import jax.numpy as jnp @@ -17,7 +16,7 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder # TODO: remove and import from diffusers.utils when the new version of diffusers is released from packaging import version @@ -25,7 +24,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed from diffusers import ( FlaxAutoencoderKL, @@ -326,16 +325,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng): if model.config.vocab_size == new_num_tokens or new_num_tokens is None: return @@ -367,22 +356,14 @@ def main(): set_seed(args.seed) if jax.process_index() == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -640,7 +621,7 @@ def compute_loss(params): tokenizer=tokenizer, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), ) pipeline.save_pretrained( @@ -661,7 +642,12 @@ def compute_loss(params): jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": diff --git a/examples/research_projects/multi_subject_dreambooth/requirements.txt b/examples/research_projects/multi_subject_dreambooth/requirements.txt index bbf6c5bec69c..e19b0ce60bf4 100644 --- a/examples/research_projects/multi_subject_dreambooth/requirements.txt +++ b/examples/research_projects/multi_subject_dreambooth/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 2ea6217e576f..a1016b50e7b2 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -6,7 +6,6 @@ import os import warnings from pathlib import Path -from typing import Optional import datasets import torch @@ -16,7 +15,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from PIL import Image from torch.utils.data import Dataset from torchvision import transforms @@ -463,16 +462,6 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -584,22 +573,14 @@ def main(args): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) @@ -886,7 +867,12 @@ def main(args): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/onnxruntime/text_to_image/requirements.txt b/examples/research_projects/onnxruntime/text_to_image/requirements.txt index b597d5464f1e..2dbadea4474a 100644 --- a/examples/research_projects/onnxruntime/text_to_image/requirements.txt +++ b/examples/research_projects/onnxruntime/text_to_image/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 datasets diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 637b35b3f695..61312fb3a4b3 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -19,7 +19,6 @@ import os import random from pathlib import Path -from typing import Optional import datasets import numpy as np @@ -31,7 +30,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from onnxruntime.training.ortmodule import ORTModule from torchvision import transforms from tqdm.auto import tqdm @@ -313,16 +312,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - dataset_name_mapping = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -339,7 +328,7 @@ def main(): mixed_precision=args.mixed_precision, log_with=args.report_to, logging_dir=logging_dir, - accelerator_project_config=accelerator_project_config, + project_config=accelerator_project_config, ) # Make one log on every process with the configuration for debugging. @@ -364,22 +353,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( @@ -660,7 +641,7 @@ def collate_fn(examples): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -732,7 +713,12 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/onnxruntime/textual_inversion/requirements.txt b/examples/research_projects/onnxruntime/textual_inversion/requirements.txt index 3a1731c228fd..c1a94eac83e6 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/requirements.txt +++ b/examples/research_projects/onnxruntime/textual_inversion/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 8d2c4c3c0bd4..a3d24066ad7a 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -19,7 +19,6 @@ import os import random from pathlib import Path -from typing import Optional import datasets import numpy as np @@ -31,7 +30,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from onnxruntime.training.ortmodule import ORTModule # TODO: remove and import from diffusers.utils when the new version of diffusers is released @@ -463,16 +462,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) @@ -514,22 +503,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load tokenizer if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -851,7 +832,12 @@ def main(): save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt b/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt index bbc690556020..f366720afd11 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt @@ -1,3 +1,3 @@ -accelerate +accelerate>=0.16.0 torchvision datasets diff --git a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py index e64a20500bea..adf6d1443d1c 100644 --- a/examples/rl/run_diffuser_locomotion.py +++ b/examples/rl/run_diffuser_locomotion.py @@ -4,17 +4,17 @@ from diffusers.experimental import ValueGuidedRLPipeline -config = dict( - n_samples=64, - horizon=32, - num_inference_steps=20, - n_guide_steps=2, # can set to 0 for faster sampling, does not use value network - scale_grad_by_std=True, - scale=0.1, - eta=0.0, - t_grad_cutoff=2, - device="cpu", -) +config = { + "n_samples": 64, + "horizon": 32, + "num_inference_steps": 20, + "n_guide_steps": 2, # can set to 0 for faster sampling, does not use value network + "scale_grad_by_std": True, + "scale": 0.1, + "eta": 0.0, + "t_grad_cutoff": 2, + "device": "cpu", +} if __name__ == "__main__": diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 312ebdac524f..c84db0ceee64 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -52,7 +52,7 @@ If you have already cloned the repo, then you won't need to go through these ste With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory. **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** - + ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" export dataset_name="lambdalabs/pokemon-blip-captions" @@ -71,6 +71,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \ --lr_scheduler="constant" --lr_warmup_steps=0 \ --output_dir="sd-pokemon-model" ``` + To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). @@ -110,6 +111,22 @@ image = pipe(prompt="yoda").images[0] image.save("yoda-pokemon.png") ``` +#### Training with Min-SNR weighting + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence +by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended +value when using it is 5.0. + +You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: + +* Training without the Min-SNR weighting strategy +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) + +For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced. + +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + ## Training with LoRA Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt index a71be6715c15..31b9026efdc2 100644 --- a/examples/text_to_image/requirements.txt +++ b/examples/text_to_image/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 datasets diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6139a0e6514d..fde762814b54 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -19,7 +19,6 @@ import os import random from pathlib import Path -from typing import Optional import accelerate import datasets @@ -32,7 +31,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -42,15 +41,74 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate +from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__, log_level="INFO") +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -112,6 +170,13 @@ def parse_args(): "value if set." ), ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) parser.add_argument( "--output_dir", type=str, @@ -193,6 +258,13 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -298,6 +370,21 @@ def parse_args(): "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -315,21 +402,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - -dataset_name_mapping = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} - - def main(): args = parse_args() @@ -376,22 +448,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( @@ -429,6 +493,30 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -526,7 +614,7 @@ def load_model_hook(models, input_dir): column_names = dataset["train"].column_names # 6. Get the column names for input/target. - dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] else: @@ -645,7 +733,9 @@ def collate_fn(examples): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -714,7 +804,7 @@ def collate_fn(examples): bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -734,7 +824,23 @@ def collate_fn(examples): # Predict the noise residual and compute loss model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() @@ -769,6 +875,26 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -786,7 +912,12 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 8655634dfc34..cdfc546a8f58 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -4,7 +4,6 @@ import os import random from pathlib import Path -from typing import Optional import jax import jax.numpy as jnp @@ -17,10 +16,10 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed from diffusers import ( FlaxAutoencoderKL, @@ -34,7 +33,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = logging.getLogger(__name__) @@ -222,16 +221,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - dataset_name_mapping = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -261,22 +250,14 @@ def main(): # Handle the repository creation if jax.process_index() == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -567,7 +548,7 @@ def compute_loss(params): tokenizer=tokenizer, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), ) pipeline.save_pretrained( @@ -581,7 +562,12 @@ def compute_loss(params): ) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 3b54cc286663..a50ca222a4a0 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -20,7 +20,6 @@ import os import random from pathlib import Path -from typing import Optional import datasets import numpy as np @@ -32,7 +31,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -48,12 +47,12 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__, log_level="INFO") -def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None): +def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -73,7 +72,7 @@ def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, re --- """ model_card = f""" -# LoRA text2image fine-tuning - {repo_name} +# LoRA text2image fine-tuning - {repo_id} These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n {img_str} """ @@ -347,16 +346,6 @@ def parse_args(): return args -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - DATASET_NAME_MAPPING = { "lambdalabs/pokemon-blip-captions": ("image", "text"), } @@ -402,22 +391,13 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo_name = create_repo(repo_name, exist_ok=True) - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id # Load scheduler, tokenizer and models. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( @@ -727,7 +707,7 @@ def collate_fn(examples): bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -830,13 +810,18 @@ def collate_fn(examples): if args.push_to_hub: save_model_card( - repo_name, + repo_id, images=images, base_model=args.pretrained_model_name_or_path, dataset_name=args.dataset_name, repo_folder=args.output_dir, ) - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) # Final inference # Load previous pipeline diff --git a/examples/textual_inversion/requirements.txt b/examples/textual_inversion/requirements.txt index 7d93f3d03bd8..7a612982f4ab 100644 --- a/examples/textual_inversion/requirements.txt +++ b/examples/textual_inversion/requirements.txt @@ -1,4 +1,4 @@ -accelerate +accelerate>=0.16.0 torchvision transformers>=4.25.1 ftfy diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 92f3d27d4905..aebc524bbb36 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -20,7 +20,6 @@ import random import warnings from pathlib import Path -from typing import Optional import numpy as np import PIL @@ -31,7 +30,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder # TODO: remove and import from diffusers.utils when the new version of diffusers is released from packaging import version @@ -78,7 +77,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__) @@ -519,16 +518,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) @@ -567,22 +556,14 @@ def main(): # Handle the repository creation if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load tokenizer if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -858,7 +839,7 @@ def main(): if global_step >= args.max_train_steps: break - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: if args.push_to_hub and args.only_save_embeds: @@ -880,7 +861,12 @@ def main(): save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index e988a2552612..513548d947a0 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -4,7 +4,6 @@ import os import random from pathlib import Path -from typing import Optional import jax import jax.numpy as jnp @@ -17,7 +16,7 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard -from huggingface_hub import HfFolder, Repository, create_repo, whoami +from huggingface_hub import create_repo, upload_folder # TODO: remove and import from diffusers.utils when the new version of diffusers is released from packaging import version @@ -25,7 +24,7 @@ from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed from diffusers import ( FlaxAutoencoderKL, @@ -57,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = logging.getLogger(__name__) @@ -339,16 +338,6 @@ def __getitem__(self, i): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng): if model.config.vocab_size == new_num_tokens or new_num_tokens is None: return @@ -380,22 +369,14 @@ def main(): set_seed(args.seed) if jax.process_index() == 0: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - create_repo(repo_name, exist_ok=True, token=args.hub_token) - repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: + if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -667,7 +648,7 @@ def compute_loss(params): tokenizer=tokenizer, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"), ) pipeline.save_pretrained( @@ -688,7 +669,12 @@ def compute_loss(params): jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": diff --git a/examples/unconditional_image_generation/requirements.txt b/examples/unconditional_image_generation/requirements.txt index bbc690556020..f366720afd11 100644 --- a/examples/unconditional_image_generation/requirements.txt +++ b/examples/unconditional_image_generation/requirements.txt @@ -1,3 +1,3 @@ -accelerate +accelerate>=0.16.0 torchvision datasets diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3b784eda6a34..f38e908fcef6 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -28,7 +28,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.15.0") logger = get_logger(__name__, log_level="INFO") diff --git a/pyproject.toml b/pyproject.toml index 5ec7ae51be15..a5fe70af9ca7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ target-version = ['py37'] [tool.ruff] # Never enforce `E501` (line length violations). -ignore = ["E501", "E741", "W605"] -select = ["E", "F", "I", "W"] +ignore = ["C901", "E501", "E741", "W605"] +select = ["C", "E", "F", "I", "W"] line-length = 119 # Ignore import violations in all `__init__.py` files. diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index 4222327c23de..46595784b0ba 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -404,7 +404,7 @@ def convert_vq_autoenc_checkpoint(checkpoint, config): config = json.loads(f.read()) # unet case - key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys()) + key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()} if "encoder" in key_prefix_set and "decoder" in key_prefix_set: converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config) else: diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py index 9475f7da93fb..cc5321e33fe0 100644 --- a/scripts/convert_models_diffuser_to_diffusers.py +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -24,29 +24,29 @@ def unet(hor): up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") state_dict = model.state_dict() - config = dict( - down_block_types=down_block_types, - block_out_channels=block_out_channels, - up_block_types=up_block_types, - layers_per_block=1, - use_timestep_embedding=True, - out_block_type="OutConv1DBlock", - norm_num_groups=8, - downsample_each_block=False, - in_channels=14, - out_channels=14, - extra_in_channels=0, - time_embedding_type="positional", - flip_sin_to_cos=False, - freq_shift=1, - sample_size=65536, - mid_block_type="MidResTemporalBlock1D", - act_fn="mish", - ) + config = { + "down_block_types": down_block_types, + "block_out_channels": block_out_channels, + "up_block_types": up_block_types, + "layers_per_block": 1, + "use_timestep_embedding": True, + "out_block_type": "OutConv1DBlock", + "norm_num_groups": 8, + "downsample_each_block": False, + "in_channels": 14, + "out_channels": 14, + "extra_in_channels": 0, + "time_embedding_type": "positional", + "flip_sin_to_cos": False, + "freq_shift": 1, + "sample_size": 65536, + "mid_block_type": "MidResTemporalBlock1D", + "act_fn": "mish", + } hf_value_function = UNet1DModel(**config) print(f"length of state dict: {len(state_dict.keys())}") print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") - mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) for k, v in mapping.items(): state_dict[v] = state_dict.pop(k) hf_value_function.load_state_dict(state_dict) @@ -57,25 +57,25 @@ def unet(hor): def value_function(): - config = dict( - in_channels=14, - down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), - up_block_types=(), - out_block_type="ValueFunction", - mid_block_type="ValueFunctionMidBlock1D", - block_out_channels=(32, 64, 128, 256), - layers_per_block=1, - downsample_each_block=True, - sample_size=65536, - out_channels=14, - extra_in_channels=0, - time_embedding_type="positional", - use_timestep_embedding=True, - flip_sin_to_cos=False, - freq_shift=1, - norm_num_groups=8, - act_fn="mish", - ) + config = { + "in_channels": 14, + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": (), + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": (32, 64, 128, 256), + "layers_per_block": 1, + "downsample_each_block": True, + "sample_size": 65536, + "out_channels": 14, + "extra_in_channels": 0, + "time_embedding_type": "positional", + "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "freq_shift": 1, + "norm_num_groups": 8, + "act_fn": "mish", + } model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") state_dict = model @@ -83,7 +83,7 @@ def value_function(): print(f"length of state dict: {len(state_dict.keys())}") print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") - mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys())) for k, v in mapping.items(): state_dict[v] = state_dict.pop(k) diff --git a/scripts/convert_music_spectrogram_to_diffusers.py b/scripts/convert_music_spectrogram_to_diffusers.py new file mode 100644 index 000000000000..41ee8b914774 --- /dev/null +++ b/scripts/convert_music_spectrogram_to_diffusers.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +import argparse +import os + +import jax as jnp +import numpy as onp +import torch +import torch.nn as nn +from music_spectrogram_diffusion import inference +from t5x import checkpoints + +from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline +from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder + + +MODEL = "base_with_context" + + +def load_notes_encoder(weights, model): + model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"])) + model.position_encoding.weight = nn.Parameter( + torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False + ) + for lyr_num, lyr in enumerate(model.encoders): + ly_weight = weights[f"layers_{lyr_num}"] + lyr.layer[0].layer_norm.weight = nn.Parameter( + torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"]) + ) + + attention_weights = ly_weight["attention"] + lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T)) + + lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + + lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T)) + + model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"])) + return model + + +def load_continuous_encoder(weights, model): + model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T)) + + model.position_encoding.weight = nn.Parameter( + torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False + ) + + for lyr_num, lyr in enumerate(model.encoders): + ly_weight = weights[f"layers_{lyr_num}"] + attention_weights = ly_weight["attention"] + + lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T)) + lyr.layer[0].layer_norm.weight = nn.Parameter( + torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"]) + ) + + lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T)) + lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + + model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"])) + + return model + + +def load_decoder(weights, model): + model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T)) + model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T)) + + model.position_encoding.weight = nn.Parameter( + torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False + ) + + model.continuous_inputs_projection.weight = nn.Parameter( + torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T) + ) + + for lyr_num, lyr in enumerate(model.decoders): + ly_weight = weights[f"layers_{lyr_num}"] + lyr.layer[0].layer_norm.weight = nn.Parameter( + torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"]) + ) + + lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter( + torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T) + ) + + attention_weights = ly_weight["self_attention"] + lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T)) + lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T)) + lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T)) + lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T)) + + attention_weights = ly_weight["MultiHeadDotProductAttention_0"] + lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T)) + lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T)) + lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T)) + lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T)) + lyr.layer[1].layer_norm.weight = nn.Parameter( + torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"]) + ) + + lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"])) + lyr.layer[2].film.scale_bias.weight = nn.Parameter( + torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T) + ) + lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T)) + lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T)) + lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T)) + + model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"])) + + model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T)) + + return model + + +def main(args): + t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path) + t5_checkpoint = jnp.tree_util.tree_map(onp.array, t5_checkpoint) + + gin_overrides = [ + "from __gin__ import dynamic_registration", + "from music_spectrogram_diffusion.models.diffusion import diffusion_utils", + "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0", + "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()", + ] + + gin_file = os.path.join(args.checkpoint_path, "..", "config.gin") + gin_config = inference.parse_training_gin_file(gin_file, gin_overrides) + synth_model = inference.InferenceModel(args.checkpoint_path, gin_config) + + scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large") + + notes_encoder = SpectrogramNotesEncoder( + max_length=synth_model.sequence_length["inputs"], + vocab_size=synth_model.model.module.config.vocab_size, + d_model=synth_model.model.module.config.emb_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + num_layers=synth_model.model.module.config.num_encoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + feed_forward_proj="gated-gelu", + ) + + continuous_encoder = SpectrogramContEncoder( + input_dims=synth_model.audio_codec.n_dims, + targets_context_length=synth_model.sequence_length["targets_context"], + d_model=synth_model.model.module.config.emb_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + num_layers=synth_model.model.module.config.num_encoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + feed_forward_proj="gated-gelu", + ) + + decoder = T5FilmDecoder( + input_dims=synth_model.audio_codec.n_dims, + targets_length=synth_model.sequence_length["targets_context"], + max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time, + d_model=synth_model.model.module.config.emb_dim, + num_layers=synth_model.model.module.config.num_decoder_layers, + num_heads=synth_model.model.module.config.num_heads, + d_kv=synth_model.model.module.config.head_dim, + d_ff=synth_model.model.module.config.mlp_dim, + dropout_rate=synth_model.model.module.config.dropout_rate, + ) + + notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder) + continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder) + decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder) + + melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder") + + pipe = SpectrogramDiffusionPipeline( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + if args.save: + pipe.save_pretrained(args.output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.") + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." + ) + parser.add_argument( + "--checkpoint_path", + default=f"{MODEL}/checkpoint_500000", + type=str, + required=False, + help="Path to the original jax model checkpoint.", + ) + args = parser.parse_args() + + main(args) diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py new file mode 100644 index 000000000000..189b165c0a01 --- /dev/null +++ b/scripts/convert_original_audioldm_to_diffusers.py @@ -0,0 +1,1015 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the AudioLDM checkpoints.""" + +import argparse +import re + +import torch +from transformers import ( + AutoTokenizer, + ClapTextConfig, + ClapTextModelWithProjection, + SpeechT5HifiGan, + SpeechT5HifiGanConfig, +) + +from diffusers import ( + AudioLDMPipeline, + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.utils import is_omegaconf_available, is_safetensors_available +from diffusers.utils.import_utils import BACKENDS_MAPPING + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths +def renew_attention_paths(old_list): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a UNet config for diffusers based on the config of the original AudioLDM model. + """ + unet_params = original_config.model.params.unet_config.params + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + cross_attention_dim = ( + unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels + ) + + class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None + projection_class_embeddings_input_dim = ( + unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None + ) + class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "out_channels": unet_params.out_channels, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": cross_attention_dim, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "class_embeddings_concat": class_embeddings_concat, + } + + return config + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config +def create_vae_diffusers_config(original_config, checkpoint, image_size: int): + """ + Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original + Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215 + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + "scaling_factor": float(scaling_factor), + } + return config + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion + conversion, this function additionally converts the learnt film embedding linear layer. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"] + new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +CLAP_KEYS_TO_MODIFY_MAPPING = { + "text_branch": "text_model", + "attn": "attention.self", + "self.proj": "output.dense", + "attention.self_mask": "attn_mask", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm1": "layernorm_before", + "norm2": "layernorm_after", + "bn0": "batch_norm", +} + +CLAP_KEYS_TO_IGNORE = ["text_transform"] + +CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"] + + +def convert_open_clap_checkpoint(checkpoint): + """ + Takes a state dict and returns a converted CLAP checkpoint. + """ + # extract state dict for CLAP text embedding model, discarding the audio component + model_state_dict = {} + model_key = "cond_stage_model.model.text_" + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(model_key): + model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key) + + new_checkpoint = {} + + sequential_layers_pattern = r".*sequential.(\d+).*" + text_projection_pattern = r".*_projection.(\d+).*" + + for key, value in model_state_dict.items(): + # check if key should be ignored in mapping + if key.split(".")[0] in CLAP_KEYS_TO_IGNORE: + continue + + # check if any key needs to be modified + for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(sequential_layers_pattern, key): + # replace sequential layers with list + sequential_layer = re.match(sequential_layers_pattern, key).group(1) + + key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") + elif re.match(text_projection_pattern, key): + projecton_layer = int(re.match(text_projection_pattern, key).group(1)) + + # Because in CLAP they use `nn.Sequential`... + transformers_projection_layer = 1 if projecton_layer == 0 else 2 + + key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") + + if "audio" and "qkv" in key: + # split qkv into query key and value + mixed_qkv = value + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + new_checkpoint[key.replace("qkv", "query")] = query_layer + new_checkpoint[key.replace("qkv", "key")] = key_layer + new_checkpoint[key.replace("qkv", "value")] = value_layer + else: + new_checkpoint[key] = value + + return new_checkpoint + + +def create_transformers_vocoder_config(original_config): + """ + Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. + """ + vocoder_params = original_config.model.params.vocoder_config.params + + config = { + "model_in_dim": vocoder_params.num_mels, + "sampling_rate": vocoder_params.sampling_rate, + "upsample_initial_channel": vocoder_params.upsample_initial_channel, + "upsample_rates": list(vocoder_params.upsample_rates), + "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes), + "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes), + "resblock_dilation_sizes": [ + list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes + ], + "normalize_before": False, + } + + return config + + +def convert_hifigan_checkpoint(checkpoint, config): + """ + Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint. + """ + # extract state dict for vocoder + vocoder_state_dict = {} + vocoder_key = "first_stage_model.vocoder." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vocoder_key): + vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key) + + # fix upsampler keys, everything else is correct already + for i in range(len(config.upsample_rates)): + vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight") + vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias") + + if not config.normalize_before: + # if we don't set normalize_before then these variables are unused, so we set them to their initialised values + vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim) + vocoder_state_dict["scale"] = torch.ones(config.model_in_dim) + + return vocoder_state_dict + + +# Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73 +DEFAULT_CONFIG = { + "model": { + "params": { + "linear_start": 0.0015, + "linear_end": 0.0195, + "timesteps": 1000, + "channels": 8, + "scale_by_std": True, + "unet_config": { + "target": "audioldm.latent_diffusion.openaimodel.UNetModel", + "params": { + "extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + }, + }, + "first_stage_config": { + "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "embed_dim": 8, + "ddconfig": { + "z_channels": 8, + "resolution": 256, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + }, + }, + }, + "vocoder_config": { + "target": "audioldm.first_stage_model.vocoder", + "params": { + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "num_mels": 64, + "sampling_rate": 16000, + }, + }, + }, + }, +} + + +def load_pipeline_from_original_audioldm_ckpt( + checkpoint_path: str, + original_config_file: str = None, + image_size: int = 512, + prediction_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "ddim", + num_in_channels: int = None, + device: str = None, + from_safetensors: bool = False, +) -> AudioLDMPipeline: + """ + Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file + corresponding to the original architecture. + If `None`, will be automatically instantiated based on default values. + :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param + prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original + AudioLDM checkpoints. + :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically + inferred. + :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", + "euler-ancestral", "dpm", "ddim"]`. + :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract + the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually + yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning. + :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If + `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors + instead of PyTorch. + :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + if not is_omegaconf_available(): + raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) + + from omegaconf import OmegaConf + + if from_safetensors: + if not is_safetensors_available(): + raise ValueError(BACKENDS_MAPPING["safetensors"][1]) + + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + original_config = DEFAULT_CONFIG + original_config = OmegaConf.create(original_config) + else: + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + prediction_type = "v_prediction" + else: + if prediction_type is None: + prediction_type = "epsilon" + + if image_size is None: + image_size = 512 + + num_train_timesteps = original_config.model.params.timesteps + beta_start = original_config.model.params.linear_start + beta_end = original_config.model.params.linear_end + + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DModel + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet = UNet2DConditionModel(**unet_config) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model + vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the text model + # AudioLDM uses the same configuration and tokenizer as the original CLAP model + config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused") + tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + + converted_text_model = convert_open_clap_checkpoint(checkpoint) + text_model = ClapTextModelWithProjection(config) + + missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False) + # we expect not to have token_type_ids in our original state dict so let's ignore them + missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}") + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}") + + # Convert the vocoder model + vocoder_config = create_transformers_vocoder_config(original_config) + vocoder_config = SpeechT5HifiGanConfig(**vocoder_config) + converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config) + + vocoder = SpeechT5HifiGan(vocoder_config) + vocoder.load_state_dict(converted_vocoder_checkpoint) + + # Instantiate the diffusers pipeline + pipe = AudioLDMPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--scheduler_type", + default="ddim", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=("The image size that the model was trained on."), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=("The prediction type that the model was trained on."), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + args = parser.parse_args() + + pipe = load_pipeline_from_original_audioldm_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + prediction_type=args.prediction_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + num_in_channels=args.num_in_channels, + from_safetensors=args.from_safetensors, + device=args.device, + ) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index b90737892815..de64095523b6 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -16,6 +16,8 @@ import argparse +import torch + from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt @@ -123,6 +125,7 @@ parser.add_argument( "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." ) + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") args = parser.parse_args() pipe = download_from_original_stable_diffusion_ckpt( @@ -143,6 +146,9 @@ controlnet=args.controlnet, ) + if args.half: + pipe.to(torch_dtype=torch.float16) + if args.controlnet: # only save the controlnet model pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py index 93eb7e6c4522..b895e08e9de9 100644 --- a/scripts/convert_versatile_diffusion_to_diffusers.py +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -19,7 +19,7 @@ import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, @@ -280,17 +280,17 @@ def create_image_unet_diffusers_config(unet_params): if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") - config = dict( - sample_size=None, - in_channels=unet_params.input_channels, - out_channels=unet_params.output_channels, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=unet_params.num_noattn_blocks[0], - cross_attention_dim=unet_params.context_dim, - attention_head_dim=unet_params.num_heads, - ) + config = { + "sample_size": None, + "in_channels": unet_params.input_channels, + "out_channels": unet_params.output_channels, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_noattn_blocks[0], + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": unet_params.num_heads, + } return config @@ -319,17 +319,17 @@ def create_text_unet_diffusers_config(unet_params): if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") - config = dict( - sample_size=None, - in_channels=(unet_params.input_channels, 1, 1), - out_channels=(unet_params.output_channels, 1, 1), - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=unet_params.num_noattn_blocks[0], - cross_attention_dim=unet_params.context_dim, - attention_head_dim=unet_params.num_heads, - ) + config = { + "sample_size": None, + "in_channels": (unet_params.input_channels, 1, 1), + "out_channels": (unet_params.output_channels, 1, 1), + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_noattn_blocks[0], + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": unet_params.num_heads, + } return config @@ -343,16 +343,16 @@ def create_vae_diffusers_config(vae_params): down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - config = dict( - sample_size=vae_params.resolution, - in_channels=vae_params.in_channels, - out_channels=vae_params.out_ch, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=vae_params.z_channels, - layers_per_block=vae_params.num_res_blocks, - ) + config = { + "sample_size": vae_params.resolution, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } return config @@ -774,7 +774,7 @@ def convert_vd_vae_checkpoint(checkpoint, config): vae.load_state_dict(converted_vae_checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") + image_feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") diff --git a/setup.py b/setup.py index cdf29df7f269..da75dd1e2a85 100644 --- a/setup.py +++ b/setup.py @@ -95,8 +95,10 @@ "Jinja2", "k-diffusion>=0.0.12", "librosa", + "note-seq", "numpy", "parameterized", + "protobuf>=3.20.3,<4", "pytest", "pytest-timeout", "pytest-xdist", @@ -182,13 +184,14 @@ def run(self): extras = {} extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder") -extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2") +extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2") extras["test"] = deps_list( "compel", "datasets", "Jinja2", "k-diffusion", "librosa", + "note-seq", "parameterized", "pytest", "pytest-timeout", @@ -223,7 +226,7 @@ def run(self): setup( name="diffusers", - version="0.15.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.15.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a1e736671be7..c7d850d65953 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.15.0.dev0" +__version__ = "0.15.0" from .configuration_utils import ConfigMixin from .utils import ( @@ -8,6 +8,7 @@ is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, + is_note_seq_available, is_onnx_available, is_scipy_available, is_torch_available, @@ -37,6 +38,7 @@ ControlNetModel, ModelMixin, PriorTransformer, + T5FilmDecoder, Transformer2DModel, UNet1DModel, UNet2DConditionModel, @@ -107,9 +109,11 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .loaders import TextualInversionLoaderMixin from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, + AudioLDMPipeline, CycleDiffusionPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, @@ -123,6 +127,7 @@ StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, + StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, @@ -132,6 +137,7 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, + TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, @@ -172,12 +178,21 @@ else: from .pipelines import AudioDiffusionPipeline, Mel +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .pipelines import SpectrogramDiffusionPipeline + try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: + from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL @@ -201,7 +216,16 @@ from .utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) + +try: + if not (is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_note_seq_objects import * # noqa F403 +else: + from .pipelines import MidiProcessor diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 20b7b273d5af..45930431351a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -109,13 +109,6 @@ def register_to_config(self, **kwargs): # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, # or solve in a more general way. kwargs.pop("kwargs", None) - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - if not hasattr(self, "_internal_dict"): internal_dict = kwargs else: @@ -420,7 +413,7 @@ def _get_init_keys(cls): @classmethod def extract_init_dict(cls, config_dict, **kwargs): # 0. Copy origin config dict - original_dict = {k: v for k, v in config_dict.items()} + original_dict = dict(config_dict.items()) # 1. Retrieve expected config attributes from __init__ signature expected_keys = cls._get_init_keys(cls) @@ -610,7 +603,7 @@ def init(self, *args, **kwargs): ) # Ignore private kwargs in the init. Retrieve all passed attributes - init_kwargs = {k: v for k, v in kwargs.items()} + init_kwargs = dict(kwargs.items()) # Retrieve default values fields = dataclasses.fields(self) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index eadc4c4adde1..1269cf1578a6 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -19,8 +19,10 @@ "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "librosa": "librosa", + "note-seq": "note-seq", "numpy": "numpy", "parameterized": "parameterized", + "protobuf": "protobuf>=3.20.3,<4", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py index 7de33a795c77..e4af4986faad 100644 --- a/src/diffusers/experimental/rl/value_guided_sampling.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -52,13 +52,13 @@ def __init__( self.scheduler = scheduler self.env = env self.data = env.get_dataset() - self.means = dict() + self.means = {} for key in self.data.keys(): try: self.means[key] = self.data[key].mean() except: # noqa: E722 pass - self.stds = dict() + self.stds = {} for key in self.data.keys(): try: self.stds[key] = self.data[key].std() diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index de6543800b2d..4598e1b4288c 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -99,8 +99,8 @@ def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor` """ w, h = images.size - w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h)) # resize to integer multiple of vae_scale_factor - images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample]) + w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor + images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample]) return images def preprocess( @@ -119,7 +119,7 @@ def preprocess( ) if isinstance(image[0], PIL.Image.Image): - if self.do_resize: + if self.config.do_resize: image = [self.resize(i) for i in image] image = [np.array(i).astype(np.float32) / 255.0 for i in image] image = np.stack(image, axis=0) # to np @@ -129,23 +129,27 @@ def preprocess( image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) image = self.numpy_to_pt(image) _, _, height, width = image.shape - if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0): + if self.config.do_resize and ( + height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 + ): raise ValueError( - f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}" + f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) _, _, height, width = image.shape - if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0): + if self.config.do_resize and ( + height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 + ): raise ValueError( - f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}" + f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) # expected range [0,1], normalize to [-1,1] - do_normalize = self.do_normalize + do_normalize = self.config.do_normalize if image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 31fdc46d9e1b..e814981a85c9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -13,18 +13,29 @@ # limitations under the License. import os from collections import defaultdict -from typing import Callable, Dict, Union +from typing import Callable, Dict, List, Optional, Union import torch from .models.attention_processor import LoRAAttnProcessor -from .models.modeling_utils import _get_model_file -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging +from .utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + TEXT_ENCODER_TARGET_MODULES, + _get_model_file, + deprecate, + is_safetensors_available, + is_transformers_available, + logging, +) if is_safetensors_available(): import safetensors +if is_transformers_available(): + from transformers import PreTrainedModel, PreTrainedTokenizer + logger = logging.get_logger(__name__) @@ -32,12 +43,15 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +TEXT_INVERSION_NAME = "learned_embeds.bin" +TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" + class AttnProcsLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} + self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} # we add a hook to state_dict() and load_state_dict() so that the @@ -68,12 +82,12 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict r""" Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. - This function is experimental and might change in the future. + This function is experimental and might change in the future. @@ -112,7 +126,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. - mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. @@ -120,15 +133,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). """ @@ -244,7 +250,7 @@ def save_attn_procs( ): r""" Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Arguments: save_directory (`str` or `os.PathLike`): @@ -292,5 +298,756 @@ def save_function(weights, filename): # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + + +class TextualInversionLoaderMixin: + r""" + Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder. + """ + + def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. + + Parameters: + prompt (`str` or list of `str`): + The prompt or prompts to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str` or list of `str`: The converted prompt + """ + if not isinstance(prompt, List): + prompts = [prompt] + else: + prompts = prompt + + prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] + + if not isinstance(prompt, List): + return prompts[0] + + return prompts + + def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): + r""" + Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds + to a multi-vector textual inversion embedding, this function will process the prompt so that the special token + is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual + inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. + + Parameters: + prompt (`str`): + The prompt to guide the image generation. + tokenizer (`PreTrainedTokenizer`): + The tokenizer responsible for encoding the prompt into input tokens. + + Returns: + `str`: The converted prompt + """ + tokens = tokenizer.tokenize(prompt) + for token in tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f"{token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def load_textual_inversion( + self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs + ): + r""" + Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and + `Automatic1111` formats are supported (see example below). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like + `"sd-concepts-library/low-poly-hd-logos-icons"`. + - A path to a *directory* containing textual inversion weights, e.g. + `./my_text_inversion_directory/`. + weight_name (`str`, *optional*): + Name of a custom weight file. This should be used in two cases: + + - The saved textual inversion file is in `diffusers` format, but was saved under a specific weight + name, such as `text_inv.bin`. + - The saved textual inversion file is in the "Automatic1111" form. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + Example: + + To load a textual inversion embedding vector in `diffusers` format: + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("sd-concepts-library/cat-toy") + + prompt = "A backpack" + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("cat-backpack.png") + ``` + + To load a textual inversion embedding vector in Automatic1111 format, make sure to first download the vector, + e.g. from [civitAI](https://civitai.com/models/3036?modelVersionId=9857) and then load the vector locally: + + ```py + from diffusers import StableDiffusionPipeline + import torch + + model_id = "runwayml/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + + pipe.load_textual_inversion("./charturnerv2.pt") + + prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details." + + image = pipe(prompt, num_inference_steps=50).images[0] + image.save("character.png") + ``` + """ + if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): + raise ValueError( + f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): + raise ValueError( + f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" + f" `{self.load_textual_inversion.__name__}`" + ) + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "text_inversion", + "framework": "pytorch", + } + # 1. Load textual inversion file + model_file = None + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except Exception as e: + if not allow_pickle: + raise e + + model_file = None + + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + + # 2. Load token and embedding correcly from file + if isinstance(state_dict, torch.Tensor): + if token is None: + raise ValueError( + "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." + ) + embedding = state_dict + elif len(state_dict) == 1: + # diffusers + loaded_token, embedding = next(iter(state_dict.items())) + elif "string_to_param" in state_dict: + # A1111 + loaded_token = state_dict["name"] + embedding = state_dict["string_to_param"]["*"] + + if token is not None and loaded_token != token: + logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") + else: + token = loaded_token + + embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) + + # 3. Make sure we don't mess up the tokenizer or text encoder + vocab = self.tokenizer.get_vocab() + if token in vocab: + raise ValueError( + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." + ) + elif f"{token}_1" in vocab: + multi_vector_tokens = [token] + i = 1 + while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: + multi_vector_tokens.append(f"{token}_{i}") + i += 1 + + raise ValueError( + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + ) + + is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 + + if is_multi_vector: + tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] + embeddings = [e for e in embedding] # noqa: C416 + else: + tokens = [token] + embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] + + # add tokens and get ids + self.tokenizer.add_tokens(tokens) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + + # resize token embeddings and set new embeddings + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + for token_id, embedding in zip(token_ids, embeddings): + self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding + + logger.info(f"Loaded textual inversion embedding for {token}.") + + +class LoraLoaderMixin: + r""" + Utility class for handling the loading LoRA layers into UNet (of class [`UNet2DConditionModel`]) and Text Encoder + (of class [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + """ + text_encoder_name = "text_encoder" + unet_name = "unet" + + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers (such as LoRA) into [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + + # Load the layers corresponding to UNet. + if all(key.startswith(self.unet_name) for key in keys): + logger.info(f"Loading {self.unet_name}.") + unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)} + self.unet.load_attn_procs(unet_lora_state_dict) + + # Load the layers corresponding to text encoder and make necessary adjustments. + elif all(key.startswith(self.text_encoder_name) for key in keys): + logger.info(f"Loading {self.text_encoder_name}.") + text_encoder_lora_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) + } + attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict) + self._modify_text_encoder(attn_procs_text_encoder) + + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any prefix. + elif not all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ): + self.unet.load_attn_procs(state_dict) + deprecation_message = "You have saved the LoRA weights using the old format. This will be" + " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" + " in a dictionary and then create a new dictionary like the following:" + " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." + deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) + + def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. + + Parameters: + attn_processors: Dict[str, `LoRAAttnProcessor`]: + A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. + """ + # Loop over the original attention modules. + for name, _ in self.text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + # Retrieve the module and its corresponding LoRA processor. + module = self.text_encoder.get_submodule(name) + # Construct a new function that performs the LoRA merging. We will monkey patch + # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward + + def new_forward(x): + return old_forward(x) + lora_layer(x) + + # Monkey-patch. + module.forward = new_forward + + def _get_lora_layer_attribute(self, name: str) -> str: + if "q_proj" in name: + return "to_q_lora" + elif "v_proj" in name: + return "to_v_lora" + elif "k_proj" in name: + return "to_k_lora" + else: + return "to_out_lora" + + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers for + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + Returns: + `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding + [`LoRAAttnProcessor`]. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + + # set correct dtype & device + attn_processors = { + k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() + } + return attn_processors + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, torch.nn.Module] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + ): + r""" + Save the LoRA parameters corresponding to the UNet and the text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the + serialization process easier and cleaner. + text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from + `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state + dict. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + # Create a flat dictionary. + state_dict = {} + if unet_lora_layers is not None: + unet_lora_state_dict = { + f"{self.unet_name}.{module_name}": param + for module_name, param in unet_lora_layers.state_dict().items() + } + state_dict.update(unet_lora_state_dict) + if text_encoder_lora_layers is not None: + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param + for module_name, param in text_encoder_lora_layers.state_dict().items() + } + state_dict.update(text_encoder_lora_state_dict) + + # Save the model + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 752aeb409f57..23839c84af45 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -21,6 +21,7 @@ from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer + from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel @@ -29,5 +30,6 @@ from .vq_model import VQModel if is_flax_available(): + from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f271e00f8639..5538a7b8249d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -224,7 +224,14 @@ def __init__( f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) + # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -235,10 +242,16 @@ def __init__( upcast_attention=upcast_attention, ) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, @@ -248,30 +261,13 @@ def __init__( bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = None - - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) else: self.norm2 = None + self.attn2 = None # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) def forward( self, @@ -283,6 +279,8 @@ def forward( cross_attention_kwargs=None, class_labels=None, ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: @@ -292,7 +290,6 @@ def forward( else: norm_hidden_states = self.norm1(hidden_states) - # 1. Self-Attention cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, @@ -304,6 +301,7 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states + # 2. Cross-Attention if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) @@ -311,7 +309,6 @@ def forward( # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly # prepare attention mask here - # 2. Cross-Attention attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1a47d728c2f9..4f78b324a8e2 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -12,10 +12,110 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import math + import flax.linen as nn +import jax import jax.numpy as jnp +def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): + """Multi-head dot product attention with a limited number of queries.""" + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / jnp.sqrt(k_features) + + @functools.partial(jax.checkpoint, prevent_cse=False) + def summarize_chunk(query, key, value): + attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) + + max_score = jnp.max(attn_weights, axis=-1, keepdims=True) + max_score = jax.lax.stop_gradient(max_score) + exp_weights = jnp.exp(attn_weights - max_score) + + exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) + max_score = jnp.einsum("...qhk->...qh", max_score) + + return (exp_values, exp_weights.sum(axis=-1), max_score) + + def chunk_scanner(chunk_idx): + # julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] + ) + + # julienne value array + value_chunk = jax.lax.dynamic_slice( + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] + ) + + return summarize_chunk(query, key_chunk, value_chunk) + + chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) + + global_max = jnp.max(chunk_max, axis=0, keepdims=True) + max_diffs = jnp.exp(chunk_max - global_max) + + chunk_values *= jnp.expand_dims(max_diffs, axis=-1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(axis=0) + all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) + + return all_values / all_weights + + +def jax_memory_efficient_attention( + query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 +): + r""" + Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 + https://github.com/AminRezaei0x443/memory-efficient-attention + + Args: + query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) + key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) + value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + numerical precision for computation + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array value must divide key_value_length equally without remainder + + Returns: + (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) + """ + num_q, num_heads, q_features = query.shape[-3:] + + def chunk_scanner(chunk_idx, _): + # julienne query array + query_chunk = jax.lax.dynamic_slice( + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] + ) + + return ( + chunk_idx + query_chunk_size, # unused ignore it + _query_chunk_attention( + query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size + ), + ) + + _, res = jax.lax.scan( + f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter + ) + + return jnp.concatenate(res, axis=-3) # fuse the chunked result back + + class FlaxAttention(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 @@ -29,6 +129,8 @@ class FlaxAttention(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` @@ -37,6 +139,7 @@ class FlaxAttention(nn.Module): heads: int = 8 dim_head: int = 64 dropout: float = 0.0 + use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True): key_states = self.reshape_heads_to_batch_dim(key_proj) value_states = self.reshape_heads_to_batch_dim(value_proj) - # compute attentions - attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - attention_scores = attention_scores * self.scale - attention_probs = nn.softmax(attention_scores, axis=2) + if self.use_memory_efficient_attention: + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim / 64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim / 16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim / 4) + else: + query_chunk_size = int(flatten_latent_dim) + + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + ) + + hidden_states = hidden_states.transpose(1, 0, 2) + else: + # compute attentions + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) + + # attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) - # attend to values - hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) hidden_states = self.proj_attn(hidden_states) return hidden_states @@ -108,6 +236,8 @@ class FlaxBasicTransformerBlock(nn.Module): Whether to only apply cross attention. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ dim: int n_heads: int @@ -115,12 +245,17 @@ class FlaxBasicTransformerBlock(nn.Module): dropout: float = 0.0 only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False def setup(self): # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) # cross attention - self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module): only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int n_heads: int @@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient_attention: bool = False def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) @@ -202,6 +340,7 @@ def setup(self): dropout=self.dropout, only_cross_attention=self.only_cross_attention, dtype=self.dtype, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) for _ in range(self.depth) ] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 30026cd89ff9..f2a5a376bf39 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -56,11 +56,13 @@ def __init__( bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, + only_cross_attention: bool = False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -68,7 +70,6 @@ def __init__( cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax - self.cross_attention_norm = cross_attention_norm self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -79,22 +80,54 @@ def __init__( self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) else: self.group_norm = None - if cross_attention_norm: + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) @@ -222,11 +255,15 @@ def batch_to_head_dim(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def head_to_batch_dim(self, tensor): + def head_to_batch_dim(self, tensor, out_dim=3): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor def get_attention_scores(self, query, key, attention_mask=None): @@ -260,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None): return attention_probs - def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): if batch_size is None: deprecate( "batch_size=None", @@ -287,10 +324,34 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None) else: attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + return attention_mask + def norm_encoder_hidden_states(self, encoder_hidden_states): + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + class AttnProcessor: def __call__( @@ -308,8 +369,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -375,7 +436,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -400,27 +464,34 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) @@ -437,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class AttnAddedKVProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query, out_dim=4) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key, out_dim=4) + value = attn.head_to_batch_dim(value, out_dim=4) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op @@ -452,8 +581,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -496,8 +625,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -546,7 +675,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) @@ -583,8 +715,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -627,30 +759,38 @@ def __init__(self, slice_size): def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj batch_size_attention, query_tokens, _ = query.shape hidden_states = torch.zeros( @@ -686,10 +826,12 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttentionProcessor = Union[ AttnProcessor, + AttnProcessor2_0, XFormersAttnProcessor, SlicedAttnProcessor, AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, LoRAAttnProcessor, LoRAXFormersAttnProcessor, ] diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 8f65c2357cac..5d1c54a9af25 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -18,7 +18,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, apply_forward_hook +from ..utils import BaseOutput, apply_forward_hook, deprecate from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -120,9 +120,19 @@ def __init__( if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size ) - self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1))) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 + @property + def block_out_channels(self): + deprecate( + "block_out_channels", + "1.0.0", + "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`", + standard_warn=False, + ) + return self.config.block_out_channels + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): module.gradient_checkpointing = value diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index ac6e64e4c779..bb608ad82a7a 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .attention_processor import AttentionProcessor +from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -368,6 +368,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py new file mode 100644 index 000000000000..3adefa84ea68 --- /dev/null +++ b/src/diffusers/models/controlnet_flax.py @@ -0,0 +1,383 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .modeling_flax_utils import FlaxModelMixin +from .unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, +) + + +@flax.struct.dataclass +class FlaxControlNetOutput(BaseOutput): + down_block_res_samples: jnp.ndarray + mid_block_res_sample: jnp.ndarray + + +class FlaxControlNetConditioningEmbedding(nn.Module): + conditioning_embedding_channels: int + block_out_channels: Tuple[int] = (16, 32, 96, 256) + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_in = nn.Conv( + self.block_out_channels[0], + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + blocks = [] + for i in range(len(self.block_out_channels) - 1): + channel_in = self.block_out_channels[i] + channel_out = self.block_out_channels[i + 1] + conv1 = nn.Conv( + channel_in, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv1) + conv2 = nn.Conv( + channel_out, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + blocks.append(conv2) + self.blocks = blocks + + self.conv_out = nn.Conv( + self.conditioning_embedding_channels, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = nn.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = nn.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +@flax_register_to_config +class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 Γ— 512 images into smaller 64 Γ— 64 β€œlatent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 Γ— 64 feature space to match the + convolution size. We use a tiny network E(Β·) of four convolution layers with 4 Γ— 4 kernels and 2 Γ— 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", + "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): + The channel order of conditional image. Will convert it to `rgb` if it's `bgr` + conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in conditioning_embedding layer + + + """ + sample_size: int = 32 + in_channels: int = 4 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + only_cross_attention: Union[bool, Tuple[bool]] = False + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: Union[int, Tuple[int]] = 8 + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + controlnet_conditioning_channel_order: str = "rgb" + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) + + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8) + controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] + + def setup(self): + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=self.conditioning_embedding_out_channels, + ) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + attention_head_dim = self.attention_head_dim + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(self.down_block_types) + + # down + down_blocks = [] + controlnet_down_blocks = [] + + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + attn_num_head_channels=attention_head_dim[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + + for _ in range(self.layers_per_block): + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv( + output_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + controlnet_down_blocks.append(controlnet_block) + + self.down_blocks = down_blocks + self.controlnet_down_blocks = controlnet_down_blocks + + # mid + mid_block_channel = block_out_channels[-1] + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=mid_block_channel, + dropout=self.dropout, + attn_num_head_channels=attention_head_dim[-1], + use_linear_projection=self.use_linear_projection, + dtype=self.dtype, + ) + + self.controlnet_mid_block = nn.Conv( + mid_block_channel, + kernel_size=(1, 1), + padding="VALID", + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + dtype=self.dtype, + ) + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + controlnet_cond, + conditioning_scale: float = 1.0, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxControlNetOutput, Tuple]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor + conditioning_scale: (`float`) the scale factor for controlnet outputs + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + channel_order = self.controlnet_conditioning_channel_order + if channel_order == "bgr": + controlnet_cond = jnp.flip(controlnet_cond, axis=1) + + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1)) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + # 5. contronet blocks + controlnet_down_block_res_samples = () + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return FlaxControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e51b40ce4509..6a849f6f0e45 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -16,27 +16,22 @@ import inspect import os -import warnings from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from packaging import version -from requests import HTTPError from torch import Tensor, device from .. import __version__ from ..utils import ( CONFIG_NAME, - DEPRECATED_REVISION_ARGS, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, + _add_variant, + _get_model_file, is_accelerate_available, is_safetensors_available, is_torch_version, @@ -144,15 +139,6 @@ def load(module: torch.nn.Module, prefix=""): return error_msgs -def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: - if variant is not None: - splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] - weights_name = ".".join(splits) - - return weights_name - - class ModelMixin(torch.nn.Module): r""" Base class for all models. @@ -579,10 +565,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " those weights or else make sure your checkpoint file is correct." ) + empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): accepts_dtype = "dtype" in set( inspect.signature(set_module_tensor_to_device).parameters.keys() ) + + if empty_state_dict[param_name].shape != param.shape: + raise ValueError( + f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + if accepts_dtype: set_module_tensor_to_device( model, param_name, param_device, value=param, dtype=torch_dtype @@ -647,7 +640,7 @@ def _load_pretrained_model( ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() - loaded_keys = [k for k in state_dict.keys()] + loaded_keys = list(state_dict.keys()) expected_keys = list(model_state_dict.keys()) @@ -782,121 +775,3 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - - -def _get_model_file( - pretrained_model_name_or_path, - *, - weights_name, - subfolder, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - user_agent, - revision, - commit_hash=None, -): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): - return pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, weights_name) - return model_file - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - return model_file - else: - raise EnvironmentError( - f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." - ) - else: - # 1. First check if deprecated way of loading from branches is used - if ( - revision in DEPRECATED_REVISION_ARGS - and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) - and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0") - ): - try: - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=_add_variant(weights_name, revision), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision or commit_hash, - ) - warnings.warn( - f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - return model_file - except: # noqa: E722 - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", - FutureWarning, - ) - try: - # 2. Load model file as usual - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=weights_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision or commit_hash, - ) - return model_file - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." - ) - except HTTPError as err: - raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {weights_name} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {weights_name}" - ) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 98f8f19c896a..d9d539959c09 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -459,6 +459,7 @@ def __init__( pre_norm=True, eps=1e-6, non_linearity="swish", + skip_time_act=False, time_embedding_norm="default", # default, scale_shift, ada_group kernel=None, output_scale_factor=1.0, @@ -479,6 +480,7 @@ def __init__( self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act if groups_out is None: groups_out = groups @@ -570,7 +572,9 @@ def forward(self, input_tensor, temb): hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb diff --git a/src/diffusers/models/t5_film_transformer.py b/src/diffusers/models/t5_film_transformer.py new file mode 100644 index 000000000000..1c41e656a9db --- /dev/null +++ b/src/diffusers/models/t5_film_transformer.py @@ -0,0 +1,321 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from .attention_processor import Attention +from .embeddings import get_timestep_embedding +from .modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.Sequential( + nn.Linear(d_model, d_model * 4, bias=False), + nn.SiLU(), + nn.Linear(d_model * 4, d_model * 4, bias=False), + nn.SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = nn.ModuleList() + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Linear(d_model, input_dims, bias=False) + + def encoder_decoder_mask(self, query_input, key_input): + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. + time_steps = get_timestep_embedding( + decoder_noise_time * self.config.max_decoder_noise_time, + embedding_dim=self.config.d_model, + max_period=self.config.max_decoder_noise_time, + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. + decoder_positions = torch.broadcast_to( + torch.arange(seq_length, device=decoder_input_tokens.device), + (batch, seq_length), + ) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = torch.ones( + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype + ) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Module): + def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): + super().__init__() + self.layer = nn.ModuleList() + + # cond self attention: layer 0 + self.layer.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + self.layer.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + self.layer.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + def forward( + self, + hidden_states, + conditioning_emb=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + ): + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( + encoder_hidden_states.dtype + ) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Module): + def __init__(self, d_model, d_kv, num_heads, dropout_rate): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states, + conditioning_emb=None, + attention_mask=None, + ): + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states, + key_value_states=None, + attention_mask=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Module): + def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, hidden_states, conditioning_emb=None): + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, d_model, d_ff, dropout_rate): + super().__init__() + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) + self.wo = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout_rate) + self.act = NewGELUActivation() + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class T5FiLMLayer(nn.Module): + """ + FiLM Layer + """ + + def __init__(self, in_features, out_features): + super().__init__() + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) + + def forward(self, x, conditioning_emb): + emb = self.scale_bias(conditioning_emb) + scale, shift = torch.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 5062295fc668..c7755bb3ed45 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -19,7 +19,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @@ -47,6 +47,9 @@ class UNet1DModel(ModelMixin, ConfigMixin): sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + extra_in_channels (`int`, *optional*, defaults to 0): + Number of additional channels to be added to the input of the first down block. Useful for cases where the + input data has more channels than what the model is initially designed for. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : @@ -187,6 +190,16 @@ def __init__( fc_dim=block_out_channels[-1] // 4, ) + @property + def in_channels(self): + deprecate( + "in_channels", + "1.0.0", + "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead", + standard_warn=False, + ) + return self.config.in_channels + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2df6e60d88c9..a83e4917c143 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -18,7 +18,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -44,7 +44,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. @@ -215,6 +216,16 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + @property + def in_channels(self): + deprecate( + "in_channels", + "1.0.0", + "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead", + standard_warn=False, + ) + return self.config.in_channels + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3070351279b8..439c5c34b601 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,10 +15,11 @@ import numpy as np import torch +import torch.nn.functional as F from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .attention_processor import Attention, AttnAddedKVProcessor +from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -42,6 +43,9 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -68,6 +72,8 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": return AttnDownBlock2D( @@ -119,6 +125,10 @@ def get_down_block( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -214,6 +224,9 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -241,6 +254,8 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: @@ -279,6 +294,10 @@ def get_up_block( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -562,6 +581,9 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -585,11 +607,16 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ] attentions = [] for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=in_channels, @@ -600,7 +627,9 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=AttnAddedKVProcessor(), + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, ) ) resnets.append( @@ -615,6 +644,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) @@ -1247,6 +1277,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, + skip_time_act=False, ): super().__init__() resnets = [] @@ -1265,6 +1296,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) @@ -1284,6 +1316,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, down=True, ) ] @@ -1337,6 +1370,9 @@ def __init__( cross_attention_dim=1280, output_scale_factor=1.0, add_downsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1362,8 +1398,14 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=out_channels, @@ -1374,7 +1416,9 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=AttnAddedKVProcessor(), + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, ) ) self.attentions = nn.ModuleList(attentions) @@ -1394,6 +1438,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, down=True, ) ] @@ -1553,7 +1598,7 @@ def __init__( temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", group_size=resnet_group_size, ) ) @@ -2237,6 +2282,7 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + skip_time_act=False, ): super().__init__() resnets = [] @@ -2257,6 +2303,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) @@ -2276,6 +2323,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, up=True, ) ] @@ -2329,6 +2377,9 @@ def __init__( cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() resnets = [] @@ -2355,8 +2406,14 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=out_channels, @@ -2367,7 +2424,9 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=AttnAddedKVProcessor(), + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, ) ) self.attentions = nn.ModuleList(attentions) @@ -2387,6 +2446,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, up=True, ) ] @@ -2573,7 +2633,7 @@ def __init__( temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, - cross_attention_norm=True, + cross_attention_norm="layer_norm", upcast_attention=upcast_attention, ) ) @@ -2668,7 +2728,7 @@ def __init__( upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, - cross_attention_norm: bool = False, + cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() @@ -2684,7 +2744,7 @@ def __init__( dropout=dropout, bias=attention_bias, cross_attention_dim=None, - cross_attention_norm=False, + cross_attention_norm=None, ) # 2. Cross-Attn diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 8e9690d332c9..b8126c5f5930 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): add_downsample: bool = True use_linear_projection: bool = False only_cross_attention: bool = False + use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -72,6 +75,7 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): Number of attention heads of each spatial transformer block add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): add_upsample: bool = True use_linear_projection: bool = False only_cross_attention: bool = False + use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -209,6 +216,7 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): num_layers: int = 1 attn_num_head_channels: int = 1 use_linear_projection: bool = False + use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -341,6 +352,7 @@ def setup(self): d_head=self.in_channels // self.attn_num_head_channels, depth=1, use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 79a361763c76..1b982aedc5de 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,12 +16,13 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging -from .attention_processor import AttentionProcessor +from ..utils import BaseOutput, deprecate, logging +from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -86,18 +87,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, it will skip the normalization and activation layers in post-processing norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + encoder_hid_dim (`int`, *optional*, defaults to None): + If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, or `"projection"`. + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_act_fn (`str`, *optional*, default to `None`): + Optional activation function to use on the time embeddings only one time before they as passed to the rest + of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str, *optional*, default to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, default to `None`): @@ -106,6 +113,13 @@ class conditioning with `class_embed_type` equal to `None`. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the + `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will + default to `False`. """ _supports_gradient_checkpointing = True @@ -129,13 +143,14 @@ def __init__( up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + encoder_hid_dim: Optional[int] = None, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, @@ -143,12 +158,18 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -175,6 +196,16 @@ def __init__( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( @@ -208,6 +239,11 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) + if encoder_hid_dim is not None: + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + else: + self.encoder_hid_proj = None + # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) @@ -228,18 +264,57 @@ def __init__( # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None + if time_embedding_act_fn is None: + self.time_embed_act = None + elif time_embedding_act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn == "mish": + self.time_embed_act = nn.Mish() + elif time_embedding_act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif time_embedding_act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + only_cross_attention = [only_cross_attention] * len(down_block_types) + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -249,15 +324,15 @@ def __init__( down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], in_channels=input_channel, out_channels=output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[i], attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, @@ -265,6 +340,9 @@ def __init__( only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -272,12 +350,12 @@ def __init__( if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, @@ -287,14 +365,17 @@ def __init__( elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -307,6 +388,8 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -326,22 +409,25 @@ def __init__( up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=reversed_cross_attention_dim[i], attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -361,6 +447,16 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + @property + def in_channels(self): + deprecate( + "in_channels", + "1.0.0", + "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead", + standard_warn=False, + ) + return self.config.in_channels + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" @@ -415,6 +511,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -571,7 +673,17 @@ def forward( class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None: + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index a40473a25f55..3c2f4a88ab7f 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ @@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 flip_sin_to_cos: bool = True freq_shift: int = 0 + use_memory_efficient_attention: bool = False def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors @@ -169,6 +172,7 @@ def setup(self): add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) else: @@ -190,6 +194,7 @@ def setup(self): dropout=self.dropout, attn_num_head_channels=attention_head_dim[-1], use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) @@ -217,6 +222,7 @@ def setup(self): dropout=self.dropout, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) else: @@ -249,6 +255,8 @@ def __call__( sample, timesteps, encoder_hidden_states, + down_block_additional_residuals=None, + mid_block_additional_residual=None, return_dict: bool = True, train: bool = False, ) -> Union[FlaxUNet2DConditionOutput, Tuple]: @@ -291,9 +299,23 @@ def __call__( sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + # 4. mid sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + # 5. up for up_block in self.up_blocks: res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 9f8ee2a22aab..2c86171610bf 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -251,7 +251,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) @@ -376,7 +378,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample output_states += (hidden_states,) @@ -587,7 +591,9 @@ def forward( encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, ).sample - hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + hidden_states = temp_attn( + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 8006d0e1c127..6fb5dfa30ebf 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -20,8 +20,9 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging -from .attention_processor import AttentionProcessor +from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .transformer_temporal import TransformerTemporalModel @@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor -class UNet3DConditionModel(ModelMixin, ConfigMixin): +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns sample shaped output. @@ -372,6 +373,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value @@ -458,7 +466,9 @@ def forward( sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) - sample = self.transformer_in(sample, num_frames=num_frames).sample + sample = self.transformer_in( + sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample # 3. down down_block_res_samples = (sample,) diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index d7f923b49690..657e085062e0 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -242,6 +242,7 @@ def get_scheduler( num_training_steps: Optional[int] = None, num_cycles: int = 1, power: float = 1.0, + last_epoch: int = -1, ): """ Unified API to get any scheduler from its name. @@ -267,14 +268,14 @@ def get_scheduler( name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + return schedule_func(optimizer, last_epoch=last_epoch) # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) # All other schedulers require `num_training_steps` if num_training_steps is None: @@ -282,12 +283,22 @@ def get_scheduler( if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + last_epoch=last_epoch, ) if name == SchedulerType.POLYNOMIAL: return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + power=power, + last_epoch=last_epoch, ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch + ) diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 07f5601ee917..7562040596e9 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -7,9 +7,9 @@ components - all of which are needed to have a functioning end-to-end diffusion As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models: - [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392) - [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12) -- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPTextModel) +- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel) - a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py), -- a [CLIPFeatureExtractor](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPFeatureExtractor), +- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor), - as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py). All of these components are necessary to run stable diffusion in inference even though they were trained or created independently from each other. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 87d1a6997e59..602cf028e2e9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -3,6 +3,7 @@ is_flax_available, is_k_diffusion_available, is_librosa_available, + is_note_seq_available, is_onnx_available, is_torch_available, is_transformers_available, @@ -42,6 +43,7 @@ from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .audioldm import AudioLDMPipeline from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline @@ -56,6 +58,7 @@ StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, + StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionPipeline, StableDiffusionPix2PixZeroPipeline, @@ -65,7 +68,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .text_to_video_synthesis import TextToVideoSDPipeline + from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, @@ -122,7 +125,15 @@ from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .stable_diffusion import ( + FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 1ae82beb54a4..bf314b91116e 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -17,11 +17,12 @@ import torch from packaging import version -from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring @@ -49,7 +50,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline): +class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -73,7 +74,7 @@ class AltDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -86,7 +87,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -294,8 +295,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -312,6 +313,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -372,6 +377,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -551,8 +560,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -637,7 +646,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index b71217a4b3ec..bb8116f2f5d5 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -13,18 +13,19 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring @@ -74,7 +75,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -88,7 +89,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionImg2ImgPipeline(DiffusionPipeline): +class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Alt Diffusion. @@ -112,7 +113,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -125,7 +126,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -304,8 +305,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -322,6 +323,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -382,6 +387,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -569,6 +578,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -626,6 +636,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: @@ -687,7 +701,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -715,14 +734,15 @@ def __call__( image = latents has_nsfw_concept = None - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: - has_nsfw_concept = False + image = self.decode_latents(latents) + + if self.safety_checker is not None: + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + has_nsfw_concept = False - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py index 8f0925ac4aaa..1df76ed6c52c 100644 --- a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py +++ b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py @@ -60,9 +60,9 @@ def get_input_dims(self) -> Tuple: input_module = self.vqvae if self.vqvae is not None else self.unet # For backwards compatibility sample_size = ( - (input_module.sample_size, input_module.sample_size) - if type(input_module.sample_size) == int - else input_module.sample_size + (input_module.config.sample_size, input_module.config.sample_size) + if type(input_module.config.sample_size) == int + else input_module.config.sample_size ) return sample_size @@ -121,17 +121,17 @@ def __call__( self.scheduler.set_timesteps(steps) step_generator = step_generator or generator # For backwards compatibility - if type(self.unet.sample_size) == int: - self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size) + if type(self.unet.config.sample_size) == int: + self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) input_dims = self.get_input_dims() self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0]) if noise is None: noise = randn_tensor( ( batch_size, - self.unet.in_channels, - self.unet.sample_size[0], - self.unet.sample_size[1], + self.unet.config.in_channels, + self.unet.config.sample_size[0], + self.unet.config.sample_size[1], ), generator=generator, device=self.device, @@ -158,7 +158,7 @@ def __call__( images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) pixels_per_second = ( - self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length + self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length ) mask_start = int(mask_start_secs * pixels_per_second) mask_end = int(mask_end_secs * pixels_per_second) @@ -201,12 +201,12 @@ def __call__( images = images.cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") images = list( - map(lambda _: Image.fromarray(_[:, :, 0]), images) + (Image.fromarray(_[:, :, 0]) for _ in images) if images.shape[3] == 1 - else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images) + else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) ) - audios = list(map(lambda _: self.mel.image_to_audio(_), images)) + audios = [self.mel.image_to_audio(_) for _ in images] if not return_dict: return images, (self.mel.get_sample_rate(), audios) diff --git a/src/diffusers/pipelines/audioldm/__init__.py b/src/diffusers/pipelines/audioldm/__init__.py new file mode 100644 index 000000000000..8ddef6c3f325 --- /dev/null +++ b/src/diffusers/pipelines/audioldm/__init__.py @@ -0,0 +1,17 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) +else: + from .pipeline_audioldm import AudioLDMPipeline diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py new file mode 100644 index 000000000000..86a8fd659046 --- /dev/null +++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -0,0 +1,601 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AudioLDMPipeline + + >>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A hammer hitting a wooden surface" + >>> audio = pipe(prompt).audio[0] + ``` +""" + + +class AudioLDMPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using AudioLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations. + text_encoder ([`ClapTextModelWithProjection`]): + Frozen text-encoder. AudioLDM uses the text portion of + [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection), + specifically the [RoBERTa HSTAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + Tokenizer of class + [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer). + unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`SpeechT5HifiGan`]): + Vocoder of class + [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan). + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection, + tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and vocoder have their state dicts saved to CPU and then are moved to a `torch.device('meta') + and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.vocoder]: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + return mel_spectrogram + + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu() + return waveform + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + self.vocoder.config.model_in_dim // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_length_in_s: Optional[float] = None, + num_inference_steps: int = 10, + guidance_scale: float = 2.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + output_type: Optional[str] = "np", + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`. + instead. + audio_length_in_s (`int`, *optional*, defaults to 5.12): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 10): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`, + usually at the expense of lower sound quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generate image. Choose between: + - `"np"`: Return Numpy `np.ndarray` objects. + - `"pt"`: Return PyTorch `torch.Tensor` objects. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated audios. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + mel_spectrogram = self.decode_latents(latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index 018e020491ce..1bfed086e8c6 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -61,7 +61,7 @@ def __call__( to make generation deterministic. audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* - `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. + `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. @@ -73,27 +73,29 @@ def __call__( if audio_length_in_s is None: audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate - sample_size = audio_length_in_s * self.unet.sample_rate + sample_size = audio_length_in_s * self.unet.config.sample_rate down_scale_factor = 2 ** len(self.unet.up_blocks) if sample_size < 3 * down_scale_factor: raise ValueError( f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" - f" {3 * down_scale_factor / self.unet.sample_rate}." + f" {3 * down_scale_factor / self.unet.config.sample_rate}." ) original_sample_size = int(sample_size) if sample_size % down_scale_factor != 0: - sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor + sample_size = ( + (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 + ) * down_scale_factor logger.info( - f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" - f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" + f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" " process." ) sample_size = int(sample_size) dtype = next(iter(self.unet.parameters())).dtype - shape = (batch_size, self.unet.in_channels, sample_size) + shape = (batch_size, self.unet.config.in_channels, sample_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 0e7f2258fa99..aaf53589b969 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -79,10 +79,15 @@ def __call__( """ # Sample gaussian noise to begin loop - if isinstance(self.unet.sample_size, int): - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) else: - image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 549dbb29d5e7..b4290daf852c 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -67,10 +67,15 @@ def __call__( True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # Sample gaussian noise to begin loop - if isinstance(self.unet.sample_size, int): - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) else: - image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) if self.device.type == "mps": # randn does not work reproducibly on mps diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 623b456e52b5..3e4f9425b0f6 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -135,7 +135,7 @@ def __call__( prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0] # get the initial random noise unless the user supplied it - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index 2ecf5f24a4a7..ae620d325307 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -21,7 +21,7 @@ def preprocess(image): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -112,7 +112,7 @@ def __call__( height, width = image.shape[-2:] # in_channels should be 6: 3 for latents, 3 for low resolution image - latents_shape = (batch_size, self.unet.in_channels // 2, height, width) + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) latents_dtype = next(self.unet.parameters()).dtype latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index dc0200feedb1..73c607a27187 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -73,7 +73,7 @@ def __call__( """ latents = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) latents = latents.to(self.device) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 353805228671..ca0a90a5b5ca 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -18,7 +18,7 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor +from transformers import CLIPImageProcessor from diffusers.utils import is_accelerate_available @@ -156,7 +156,7 @@ class PaintByExamplePipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ # TODO: feature_extractor is required to encode initial images (if they are in PIL format), @@ -170,7 +170,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = False, ): super().__init__() diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 30e32c3d66e9..6ab0b80ee655 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -278,7 +278,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" - >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( + >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) @@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pt = kwargs.pop("from_pt", False) + use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) dtype = kwargs.pop("dtype", None) # 1. Download the checkpoints and configs @@ -365,7 +366,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): - loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) + loaded_sub_model, loaded_params = load_method( + loadable_folder, + from_pt=from_pt, + use_memory_efficient_attention=use_memory_efficient_attention, + dtype=dtype, + ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): if from_pt: @@ -470,6 +476,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + model = pipeline_class(**init_kwargs, dtype=dtype) return model, params @@ -478,7 +497,7 @@ def _get_signature_keys(obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - set(["self"]) + expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters @property diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f33b506827a..2e20c21aaf38 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -50,6 +50,7 @@ get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, + is_compiled_module, is_safetensors_available, is_torch_version, is_transformers_available, @@ -133,7 +134,7 @@ class AudioPipelineOutput(BaseOutput): audios: np.ndarray -def is_safetensors_compatible(filenames, variant=None) -> bool: +def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool: """ Checking for safetensors compatibility: - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch @@ -149,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool: sf_filenames = set() + passed_components = passed_components or [] + for filename in filenames: _, extension = os.path.splitext(filename) + if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components: + continue + if extension == ".bin": pt_filenames.append(filename) elif extension == ".safetensors": @@ -162,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool: path, filename = os.path.split(filename) filename, extension = os.path.splitext(filename) - if filename == "pytorch_model": - filename = "model" - elif filename == f"pytorch_model.{variant}": - filename = f"model.{variant}" + if filename.startswith("pytorch_model"): + filename = filename.replace("pytorch_model", "model") else: filename = filename @@ -195,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi weight_prefixes = [w.split(".")[0] for w in weight_names] # .bin, .safetensors, ... weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = "\d{5}-of-\d{5}" + + if variant is not None: + # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors` + variant_file_re = re.compile( + f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) + # `text_encoder/pytorch_model.bin.index.fp16.json` + variant_index_re = re.compile( + f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) - variant_file_regex = ( - re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})") - if variant is not None - else None + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors` + non_variant_file_re = re.compile( + f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" ) - non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}") + # `text_encoder/pytorch_model.bin.index.json` + non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") if variant is not None: - variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None) + variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} + variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} + variant_filenames = variant_weights | variant_indexes else: variant_filenames = set() - non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None) + non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} + non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} + non_variant_filenames = non_variant_weights | non_variant_indexes + # all variant filenames will be used by default usable_filenames = set(variant_filenames) + + def convert_to_variant(filename): + if "index" in filename: + variant_filename = filename.replace("index", f"index.{variant}") + elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: + variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + else: + variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" + return variant_filename + for f in non_variant_filenames: - variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}" + variant_filename = convert_to_variant(f) if variant_filename not in usable_filenames: usable_filenames.add(f) @@ -225,7 +256,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, use_auth_token=use_auth_token, revision=None, ) - filenames = set(sibling.rfilename for sibling in info.siblings) + filenames = {sibling.rfilename for sibling in info.siblings} comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] @@ -255,7 +286,14 @@ def maybe_raise_or_warn( if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + model_cls = sub_model.__class__ + if is_compiled_module(sub_model): + model_cls = sub_model._orig_mod.__class__ + + if not issubclass(model_cls, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" @@ -284,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p return class_obj, class_candidates +def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None): + if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + return get_class_from_dynamic_module( + custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision + ) + + if class_obj != DiffusionPipeline: + return class_obj + + diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) + return getattr(diffusers_module, config["_class_name"]) + + def load_sub_model( library_name: str, class_name: str, @@ -419,6 +478,10 @@ def register_modules(self, **kwargs): if module is None: register_dict = {name: (None, None)} else: + # register the original module, not the dynamo compiled one + if is_compiled_module(module): + module = module._orig_mod + library = module.__module__.split(".")[0] # check if the module is a pipeline module @@ -443,6 +506,21 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) + def __setattr__(self, name: str, value: Any): + if hasattr(self, name) and hasattr(self.config, name): + # We need to overwrite the config if name exists in config + if isinstance(getattr(self.config, name), (tuple, list)): + if value is not None and self.config[name][0] is not None: + class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) + else: + class_library_tuple = (None, None) + + self.register_to_config(**{name: class_library_tuple}) + else: + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -484,6 +562,12 @@ def is_saveable_module(name, value): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + if is_compiled_module(sub_model): + sub_model = sub_model._orig_mod + model_cls = sub_model.__class__ + save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): @@ -550,9 +634,11 @@ def module_is_offloaded(module): f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." ) - module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names, _ = self._get_signature_keys(self) + module_names = [m for m in module_names if hasattr(self, m)] + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for name in module_names.keys(): + for name in module_names: module = getattr(self, name) if isinstance(module, torch.nn.Module): module.to(torch_device, torch_dtype) @@ -577,8 +663,10 @@ def device(self) -> torch.device: Returns: `torch.device`: The torch device on which the pipeline is located. """ - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): + module_names, _ = self._get_signature_keys(self) + module_names = [m for m in module_names if hasattr(self, m)] + + for name in module_names: module = getattr(self, name) if isinstance(module, torch.nn.Module): return module.device @@ -761,7 +849,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) - kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -776,8 +864,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, from_flax=from_flax, + use_safetensors=use_safetensors, custom_pipeline=custom_pipeline, + custom_revision=custom_revision, variant=variant, + **kwargs, ) else: cached_folder = pretrained_model_name_or_path @@ -792,29 +883,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P for folder in os.listdir(cached_folder): folder_path = os.path.join(cached_folder, folder) is_folder = os.path.isdir(folder_path) and folder in config_dict - variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path)) + variant_exists = is_folder and any( + p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) + ) if variant_exists: model_variants[folder] = variant # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it - if custom_pipeline is not None: - if custom_pipeline.endswith(".py"): - path = Path(custom_pipeline) - # decompose into folder & file - file_name = path.name - custom_pipeline = path.parent.absolute() - else: - file_name = CUSTOM_PIPELINE_FILE_NAME - - pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision - ) - elif cls != DiffusionPipeline: - pipeline_class = cls - else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + pipeline_class = _get_pipeline_class( + cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision + ) # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( @@ -1077,6 +1156,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: revision = kwargs.pop("revision", None) from_flax = kwargs.pop("from_flax", False) custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) @@ -1115,7 +1195,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] - filenames = set(sibling.rfilename for sibling in info.siblings) + filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) # if the whole pipeline is cached we don't have to ping the Hub @@ -1126,7 +1206,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: pretrained_model_name, use_auth_token, variant, revision, model_filenames ) - model_folder_names = set([os.path.split(f)[0] for f in model_filenames]) + model_folder_names = {os.path.split(f)[0] for f in model_filenames} # all filenames compatible with variant will be added allow_patterns = list(model_filenames) @@ -1135,7 +1215,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # this enables downloading schedulers, tokenizers, ... allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] # also allow downloading config.json files with the model - allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names] + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] allow_patterns += [ SCHEDULER_CONFIG_NAME, @@ -1144,21 +1224,32 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: CUSTOM_PIPELINE_FILE_NAME, ] + # retrieve passed components that should not be downloaded + pipeline_class = _get_pipeline_class( + cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision + ) + expected_components, _ = cls._get_signature_keys(pipeline_class) + passed_components = [k for k in expected_components if k in kwargs] + if ( use_safetensors and not allow_pickle - and not is_safetensors_compatible(model_filenames, variant=variant) + and not is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ) ): raise EnvironmentError( f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})" ) if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant): + elif use_safetensors and is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ): ignore_patterns = ["*.bin", "*.msgpack"] - safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) - safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")]) + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} if ( len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames @@ -1169,13 +1260,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: else: ignore_patterns = ["*.safetensors", "*.msgpack"] - bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")]) - bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")]) + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: logger.warn( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." ) + # Don't download any objects that are passed + allow_patterns = [ + p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) + ] + # Don't download index files of forbidden patterns either + ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] + re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] @@ -1215,7 +1313,7 @@ def _get_signature_keys(obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - set(["self"]) + expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters @property @@ -1341,6 +1439,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): fn_recursive_set_mem_eff(child) module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names = [m for m in module_names if hasattr(self, m)] + for module_name in module_names: module = getattr(self, module_name) if isinstance(module, torch.nn.Module): @@ -1372,6 +1472,8 @@ def disable_attention_slicing(self): def set_attention_slice(self, slice_size: Optional[int]): module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names = [m for m in module_names if hasattr(self, m)] + for module_name in module_names: module = getattr(self, module_name) if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 56fb72d3f4ff..361444079311 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -77,7 +77,7 @@ def __call__( # Sample gaussian noise to begin loop image = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, device=self.device, ) diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index fabcd2610f43..f4914c46db51 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -37,7 +37,7 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -58,7 +58,7 @@ def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): if isinstance(mask[0], PIL.Image.Image): w, h = mask[0].size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] mask = np.concatenate(mask, axis=0) mask = mask.astype(np.float32) / 255.0 diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index a421a844c329..3d5374875d12 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline @@ -84,7 +84,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline): safety_checker ([`Q16SafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -98,7 +98,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -476,7 +476,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/spectrogram_diffusion/__init__.py b/src/diffusers/pipelines/spectrogram_diffusion/__init__.py new file mode 100644 index 000000000000..05b14a857630 --- /dev/null +++ b/src/diffusers/pipelines/spectrogram_diffusion/__init__.py @@ -0,0 +1,26 @@ +# flake8: noqa +from ...utils import is_note_seq_available, is_transformers_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .notes_encoder import SpectrogramNotesEncoder + from .continous_encoder import SpectrogramContEncoder + from .pipeline_spectrogram_diffusion import ( + SpectrogramContEncoder, + SpectrogramDiffusionPipeline, + T5FilmDecoder, + ) + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 +else: + from .midi_utils import MidiProcessor diff --git a/src/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py b/src/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py new file mode 100644 index 000000000000..556136d4023d --- /dev/null +++ b/src/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py @@ -0,0 +1,92 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5Config, + T5LayerNorm, +) + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + input_dims: int, + targets_context_length: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.input_proj = nn.Linear(input_dims, d_model, bias=False) + + self.position_encoding = nn.Embedding(targets_context_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + feed_forward_proj=feed_forward_proj, + dropout_rate=dropout_rate, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_inputs, encoder_inputs_mask): + x = self.input_proj(encoder_inputs) + + # terminal relative positional encodings + max_positions = encoder_inputs.shape[1] + input_positions = torch.arange(max_positions, device=encoder_inputs.device) + + seq_lens = encoder_inputs_mask.sum(-1) + input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) + x += self.position_encoding(input_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_inputs.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/src/diffusers/pipelines/spectrogram_diffusion/midi_utils.py b/src/diffusers/pipelines/spectrogram_diffusion/midi_utils.py new file mode 100644 index 000000000000..08d0878db588 --- /dev/null +++ b/src/diffusers/pipelines/spectrogram_diffusion/midi_utils.py @@ -0,0 +1,667 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +import os +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ...utils import is_note_seq_available +from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH + + +if is_note_seq_available(): + import note_seq +else: + raise ImportError("Please install note-seq via `pip install note-seq`") + + +INPUT_FEATURE_LENGTH = 2048 + +SAMPLE_RATE = 16000 +HOP_SIZE = 320 +FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) + +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 1 + +SLAKH_CLASS_PROGRAMS = { + "Acoustic Piano": 0, + "Electric Piano": 4, + "Chromatic Percussion": 8, + "Organ": 16, + "Acoustic Guitar": 24, + "Clean Electric Guitar": 26, + "Distorted Electric Guitar": 29, + "Acoustic Bass": 32, + "Electric Bass": 33, + "Violin": 40, + "Viola": 41, + "Cello": 42, + "Contrabass": 43, + "Orchestral Harp": 46, + "Timpani": 47, + "String Ensemble": 48, + "Synth Strings": 50, + "Choir and Voice": 52, + "Orchestral Hit": 55, + "Trumpet": 56, + "Trombone": 57, + "Tuba": 58, + "French Horn": 60, + "Brass Section": 61, + "Soprano/Alto Sax": 64, + "Tenor Sax": 66, + "Baritone Sax": 67, + "Oboe": 68, + "English Horn": 69, + "Bassoon": 70, + "Clarinet": 71, + "Pipe": 73, + "Synth Lead": 80, + "Synth Pad": 88, +} + + +@dataclasses.dataclass +class NoteRepresentationConfig: + """Configuration note representations.""" + + onsets_only: bool + include_ties: bool + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: Optional[int] = None + program: Optional[int] = None + is_drum: Optional[bool] = None + instrument: Optional[int] = None + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + + # velocity bin for active pitches and programs + active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Tokenizer: + def __init__(self, regular_ids: int): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + + def encode(self, token_ids): + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" + ) + encoded.append(token_id + self._num_special_tokens) + + # Add EOS token + encoded.append(1) + + # Pad to till INPUT_FEATURE_LENGTH + encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) + + return encoded + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from + Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not + include things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required + and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. + + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"Event value {event.value} is not within valid range " + f"[{er.min_value}, {er.max_value}] for type {event.type}" + ) + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event.type}") + + def event_type_range(self, event_type: str) -> Tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event_type}") + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event index: {index}") + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +def drop_programs(tokens, codec: Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range("program") + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range("program") + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), + # map each program to the first program in its MIDI class + "midi_class": ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) + ), + # leave programs as is + "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), +} + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): + """ + equivalent of tf.signal.frame + """ + signal_length = signal.shape[axis] + if pad_end: + frames_overlap = frame_length - frame_step + rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) + pad_size = int(frame_length - rest_samples) + + if pad_size != 0: + pad_axis = [0] * signal.ndim + pad_axis[axis] = pad_size + signal = F.pad(signal, pad_axis, "constant", pad_value) + frames = signal.unfold(axis, frame_length, frame_step) + return frames + + +def program_to_slakh_program(program): + # this is done very hackily, probably should use a custom mapping + for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): + if program >= slakh_program: + return slakh_program + + +def audio_to_frames( + samples, + hop_size: int, + frame_rate: int, +) -> Tuple[Sequence[Sequence[int]], torch.Tensor]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = hop_size + samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") + + # Split audio into frames. + frames = frame( + torch.Tensor(samples).unsqueeze(0), + frame_length=frame_size, + frame_step=frame_size, + pad_end=False, # TODO check why its off by 1 here when True + ) + + num_frames = len(samples) // frame_size + + times = np.arange(num_frames) / frame_rate + return frames, times + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for + note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) + times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] + values = [ + NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) + for note in notes + if not note.is_drum + ] + [ + NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) + for note in notes + ] + return times, values + + +def num_velocity_bins_from_codec(codec: Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range("velocity") + return hi - lo + + +# segment an array into segments of length n +def segment(a, n): + return [a[i : i + n] for i in range(0, len(a), n)] + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def note_event_data_to_events( + state: Optional[NoteEncodingState], + value: NoteEventData, + codec: Codec, +) -> Sequence[Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [Event("pitch", value.pitch)] + else: + num_velocity_bins = num_velocity_bins_from_codec(codec) + velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [Event("velocity", velocity_bin), Event("drum", value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [ + Event("program", value.program), + Event("velocity", velocity_bin), + Event("pitch", value.pitch), + ] + + +def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [Event("program", program), Event("pitch", pitch)] + events.append(Event("tie", 0)) + return events + + +def encode_and_index_events( + state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None +): + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio + frame. This can be used e.g. to prepend events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more Event objects. + codec: An Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more Event objects. + + Returns: + events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes + splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should always be true that + event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. + """ + indices = np.argsort(event_times, kind="stable") + event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while ( + len(event_start_indices) < len(frame_times) + and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second + ): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events).astype(np.int32) + state_events = np.array(state_events).astype(np.int32) + event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + + outputs = [] + for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): + outputs.append( + { + "inputs": events, + "event_start_indices": start_indices, + "event_end_indices": end_indices, + "state_events": state_events, + "state_event_indices": event_indices, + } + ) + + return outputs + + +def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): + """Extract target sequence corresponding to audio token segment.""" + features = features.copy() + start_idx = features["event_start_indices"][0] + end_idx = features["event_end_indices"][-1] + + features[feature_key] = features[feature_key][start_idx:end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features["state_event_indices"][0] + state_event_end_idx = state_event_start_idx + 1 + while features["state_events"][state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features[feature_key] = np.concatenate( + [ + features["state_events"][state_event_start_idx:state_event_end_idx], + features[feature_key], + ], + axis=0, + ) + + return features + + +def map_midi_programs( + feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = PROGRAM_GRANULARITIES[granularity_type] + + feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) + return feature + + +def run_length_encode_shifts_fn( + features, + codec: Codec, + feature_key: str = "inputs", + state_change_event_types: Sequence[str] = (), +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones + will be removed. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] + + def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = np.array([], dtype=np.int32) + + current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) + + for event in events: + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # If this event is a state change and has the same value as the current + # state, we can skip it entirely. + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue + + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = np.minimum(codec.max_shift_steps, shift_steps) + output = np.concatenate([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = np.concatenate([output, [event]], axis=0) + + features[feature_key] = output + return features + + return run_length_encode_shifts(features) + + +def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): + tie_token = codec.encode_event(Event("tie", 0)) + state_events_end_token = tie_token if note_representation_config.include_ties else None + + features = extract_sequence_with_indices( + features, state_events_end_token=state_events_end_token, feature_key="inputs" + ) + + features = map_midi_programs(features, codec) + + features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) + + return features + + +class MidiProcessor: + def __init__(self): + self.codec = Codec( + max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, + steps_per_second=DEFAULT_STEPS_PER_SECOND, + event_ranges=[ + EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), + EventRange("tie", 0, 0), + EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), + EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + ], + ) + self.tokenizer = Tokenizer(self.codec.num_classes) + self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) + + def __call__(self, midi: Union[bytes, os.PathLike, str]): + if not isinstance(midi, bytes): + with open(midi, "rb") as f: + midi = f.read() + + ns = note_seq.midi_to_note_sequence(midi) + ns_sus = note_seq.apply_sustain_control_changes(ns) + + for note in ns_sus.notes: + if not note.is_drum: + note.program = program_to_slakh_program(note.program) + + samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) + + _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) + times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) + + events = encode_and_index_events( + state=NoteEncodingState(), + event_times=times, + event_values=values, + frame_times=frame_times, + codec=self.codec, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + ) + + events = [ + note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events + ] + input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] + + return input_tokens diff --git a/src/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py b/src/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py new file mode 100644 index 000000000000..94eaa176f3e5 --- /dev/null +++ b/src/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py @@ -0,0 +1,86 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + max_length: int, + vocab_size: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.token_embedder = nn.Embedding(vocab_size, d_model) + + self.position_encoding = nn.Embedding(max_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + vocab_size=vocab_size, + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + dropout_rate=dropout_rate, + feed_forward_proj=feed_forward_proj, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_input_tokens, encoder_inputs_mask): + x = self.token_embedder(encoder_input_tokens) + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) + x += self.position_encoding(inputs_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_input_tokens.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py new file mode 100644 index 000000000000..66155ebf7f35 --- /dev/null +++ b/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -0,0 +1,210 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...models import T5FilmDecoder +from ...schedulers import DDPMScheduler +from ...utils import is_onnx_available, logging, randn_tensor + + +if is_onnx_available(): + from ..onnx_utils import OnnxRuntimeModel + +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .continous_encoder import SpectrogramContEncoder +from .notes_encoder import SpectrogramNotesEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TARGET_FEATURE_LENGTH = 256 + + +class SpectrogramDiffusionPipeline(DiffusionPipeline): + _optional_components = ["melgan"] + + def __init__( + self, + notes_encoder: SpectrogramNotesEncoder, + continuous_encoder: SpectrogramContEncoder, + decoder: T5FilmDecoder, + scheduler: DDPMScheduler, + melgan: OnnxRuntimeModel if is_onnx_available() else Any, + ) -> None: + super().__init__() + + # From MELGAN + self.min_value = math.log(1e-5) # Matches MelGAN training. + self.max_value = 4.0 # Largest value for most examples + self.n_dims = 128 + + self.register_modules( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + + def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): + """Linearly scale features to network outputs range.""" + min_out, max_out = output_range + if clip: + features = torch.clip(features, self.min_value, self.max_value) + # Scale to [0, 1]. + zero_one = (features - self.min_value) / (self.max_value - self.min_value) + # Scale to [min_out, max_out]. + return zero_one * (max_out - min_out) + min_out + + def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): + """Invert by linearly scaling network outputs to features range.""" + min_out, max_out = input_range + outputs = torch.clip(outputs, min_out, max_out) if clip else outputs + # Scale to [0, 1]. + zero_one = (outputs - min_out) / (max_out - min_out) + # Scale to [self.min_value, self.max_value]. + return zero_one * (self.max_value - self.min_value) + self.min_value + + def encode(self, input_tokens, continuous_inputs, continuous_mask): + tokens_mask = input_tokens > 0 + tokens_encoded, tokens_mask = self.notes_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time): + timesteps = noise_time + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(input_tokens.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + logits = self.decoder( + encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps + ) + return logits + + @torch.no_grad() + def __call__( + self, + input_tokens: List[List[int]], + generator: Optional[torch.Generator] = None, + num_inference_steps: int = 100, + return_dict: bool = True, + output_type: str = "numpy", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ) -> Union[AudioPipelineOutput, Tuple]: + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) + full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) + ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + + for i, encoder_input_tokens in enumerate(input_tokens): + if i == 0: + encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( + device=self.device, dtype=self.decoder.dtype + ) + # The first chunk has no previous context. + encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + else: + # The full song pipeline does not feed in a context feature, so the mask + # will be all 0s after the feature converter. Because we know we're + # feeding in a full context chunk from the previous prediction, set it + # to all 1s. + encoder_continuous_mask = ones + + encoder_continuous_inputs = self.scale_features( + encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True + ) + + encodings_and_masks = self.encode( + input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), + continuous_inputs=encoder_continuous_inputs, + continuous_mask=encoder_continuous_mask, + ) + + # Sample encoder_continuous_inputs shaped gaussian noise to begin loop + x = randn_tensor( + shape=encoder_continuous_inputs.shape, + generator=generator, + device=self.device, + dtype=self.decoder.dtype, + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + # Denoising diffusion loop + for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + output = self.decode( + encodings_and_masks=encodings_and_masks, + input_tokens=x, + noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + ) + + # Compute previous output: x_t -> x_t-1 + x = self.scheduler.step(output, t, x, generator=generator).prev_sample + + mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) + encoder_continuous_inputs = mel[:1] + pred_mel = mel.cpu().float().numpy() + + full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, full_pred_mel) + + logger.info("Generated segment", i) + + if output_type == "numpy" and not is_onnx_available(): + raise ValueError( + "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." + ) + elif output_type == "numpy" and self.melgan is None: + raise ValueError( + "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." + ) + + if output_type == "numpy": + output = self.melgan(input_features=full_pred_mel.astype(np.float32)) + else: + output = full_pred_mel + + if not return_dict: + return (output,) + + return AudioPipelineOutput(audios=output) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 54ec4dabc73e..6bc2b58b5fef 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -51,6 +51,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline + from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline @@ -127,6 +128,7 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index ef4598433f82..a16213639526 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -274,18 +274,18 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa else: raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - config = dict( - sample_size=image_size // vae_scale_factor, - in_channels=unet_params.in_channels, - down_block_types=tuple(down_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=unet_params.num_res_blocks, - cross_attention_dim=unet_params.context_dim, - attention_head_dim=head_dim, - use_linear_projection=use_linear_projection, - class_embed_type=class_embed_type, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - ) + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + } if not controlnet: config["out_channels"] = unet_params.out_channels @@ -305,16 +305,16 @@ def create_vae_diffusers_config(original_config, image_size: int): down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - config = dict( - sample_size=image_size, - in_channels=vae_params.in_channels, - out_channels=vae_params.out_ch, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=vae_params.z_channels, - layers_per_block=vae_params.num_res_blocks, - ) + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } return config @@ -989,6 +989,7 @@ def download_from_original_stable_diffusion_ckpt( stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, + load_safety_checker: bool = True, ) -> StableDiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1028,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt( The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. """ if prediction_type == "v-prediction": prediction_type = "v_prediction" @@ -1270,8 +1273,13 @@ def download_from_original_stable_diffusion_ckpt( elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + else: + safety_checker = None + feature_extractor = None if controlnet: pipe = StableDiffusionControlNetPipeline( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 76423867add1..dd8e4f16dfc0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -19,11 +19,12 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor @@ -44,7 +45,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -118,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): return noise -class CycleDiffusionPipeline(DiffusionPipeline): +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. @@ -142,7 +143,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -155,7 +156,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -320,8 +321,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -338,6 +339,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -398,6 +403,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 28718e4778fb..3b4f77029ce4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -24,7 +24,7 @@ from flax.training.common_utils import shard from packaging import version from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( @@ -103,7 +103,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -117,7 +117,7 @@ def __init__( FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() @@ -245,9 +245,12 @@ def _generate( negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + latents_shape = ( batch_size, - self.unet.in_channels, + self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py new file mode 100644 index 000000000000..df3e79a194f8 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -0,0 +1,537 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def image_grid(imgs, rows, cols): + ... w, h = imgs[0].size + ... grid = Image.new("RGB", size=(cols * w, rows * h)) + ... for i, img in enumerate(imgs): + ... grid.paste(img, box=(i % cols * w, i // cols * h)) + ... return grid + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the unet during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_text_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing the ControlNet input condition. ControlNet use this input condition to generate + guidance to Unet. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 97a3eb01c352..6a387af364b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -23,7 +23,7 @@ from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( @@ -127,7 +127,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -141,7 +141,7 @@ def __init__( FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() @@ -268,7 +268,7 @@ def _generate( latents_shape = ( batch_size, - self.unet.in_channels, + self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) @@ -520,7 +520,7 @@ def unshard(x: jnp.ndarray): def preprocess(image, dtype): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index d964207516bc..abb57f8b62e9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from flax.training.common_utils import shard from packaging import version from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( @@ -124,7 +124,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -138,7 +138,7 @@ def __init__( FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() @@ -563,7 +563,7 @@ def unshard(x: jnp.ndarray): def preprocess_image(image, dtype): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -572,7 +572,7 @@ def preprocess_image(image, dtype): def preprocess_mask(mask, dtype): w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w, h)) mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 mask = jnp.expand_dims(mask, axis=(0, 1)) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 55b996e56bb3..eb02f6cb321c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -38,7 +38,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] @@ -51,7 +51,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -111,7 +111,15 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -125,32 +133,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -170,7 +194,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -179,6 +203,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -188,9 +214,56 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds - def __call__( + def check_inputs( self, prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]] = None, height: Optional[int] = 512, width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, @@ -200,28 +273,86 @@ def __call__( eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, ): - if isinstance(prompt, str): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + One or a list of [numpy generator(s)](TODO) to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random @@ -232,7 +363,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # get the initial random noise unless the user supplied it @@ -333,7 +469,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, ): deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 9123e5f3296d..67d3f44e6d4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -18,7 +18,7 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -40,7 +40,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -77,7 +77,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ vae_encoder: OnnxRuntimeModel @@ -87,7 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] @@ -100,7 +100,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -161,7 +161,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -175,32 +183,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -220,7 +244,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -229,6 +253,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -238,6 +264,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + def check_inputs( + self, + prompt: Union[str, List[str]], + callback_steps: int, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def __call__( self, prompt: Union[str, List[str]], @@ -249,6 +317,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -288,6 +358,13 @@ def __call__( [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -308,24 +385,21 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if generator is None: generator = np.random @@ -340,7 +414,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 46b5ce5ad6e4..0bb39c4b1c61 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -18,7 +18,7 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -77,7 +77,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ vae_encoder: OnnxRuntimeModel @@ -87,7 +87,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] @@ -100,7 +100,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -162,7 +162,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -176,32 +184,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -221,7 +245,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -230,6 +254,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -239,6 +265,54 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + @torch.no_grad() def __call__( self, @@ -254,6 +328,8 @@ def __call__( eta: float = 0.0, generator: Optional[np.random.RandomState] = None, latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -300,6 +376,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -321,23 +404,18 @@ def __call__( (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random @@ -351,7 +429,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) num_channels_latents = NUM_LATENT_CHANNELS diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 84e5f6aaab01..8ef7a781451c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -4,7 +4,7 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -19,7 +19,7 @@ def preprocess(image): w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -29,7 +29,7 @@ def preprocess(image): def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) @@ -63,7 +63,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -75,7 +75,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor def __init__( self, @@ -86,7 +86,7 @@ def __init__( unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -147,7 +147,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -161,32 +169,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -206,7 +230,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -215,6 +239,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -224,6 +250,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def __call__( self, prompt: Union[str, List[str]], @@ -236,6 +304,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -280,6 +350,13 @@ def __call__( [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -300,24 +377,21 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if generator is None: generator = np.random @@ -333,7 +407,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 45b5a50467b0..8db19c2b9109 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -31,7 +31,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -70,16 +70,85 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + noise_level TODO + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs self.check_inputs(prompt, image, noise_level, callback_steps) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -88,7 +157,13 @@ def __call__( # 3. Encode input prompt text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] @@ -199,45 +274,59 @@ def decode_latents(self, latents): image = image.transpose((0, 2, 3, 1)) return image - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device, + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - # if hasattr(text_inputs, "attention_mask"): - # attention_mask = text_inputs.attention_mask.to(device) - # else: - # attention_mask = None - - # no positional arguments to text_encoder - text_embeddings = self.text_encoder( - input_ids=text_input_ids.int().to(device), - # attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + # no positional arguments to text_encoder + prompt_embeds = self.text_encoder( + input_ids=text_input_ids.int().to(device), + # attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] - bs_embed, seq_len, _ = text_embeddings.shape + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) - text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -277,6 +366,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr ) uncond_embeddings = uncond_embeddings[0] + if do_classifier_free_guidance: seq_len = uncond_embeddings.shape[1] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) @@ -285,6 +375,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) - return text_embeddings + return prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 81b2cfa9bc3e..689febe3e891 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -17,9 +17,10 @@ import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -52,7 +53,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -76,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -89,7 +90,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -297,8 +298,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -315,6 +316,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -375,6 +380,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -554,8 +563,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -640,7 +649,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 2d32c0ba8b62..35351bae7116 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -19,8 +19,9 @@ import numpy as np import torch from torch.nn import functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers @@ -159,7 +160,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): +class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite. @@ -183,7 +184,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -196,7 +197,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -317,8 +318,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -335,6 +336,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -395,6 +400,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -741,8 +750,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -846,7 +855,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index aeb70b1b2234..12d21afbfeda 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -21,8 +21,9 @@ import PIL.Image import torch from torch import nn -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.controlnet import ControlNetOutput from ...models.modeling_utils import ModelMixin @@ -146,7 +147,7 @@ def forward( return down_block_res_samples, mid_block_res_sample -class StableDiffusionControlNetPipeline(DiffusionPipeline): +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -174,7 +175,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -188,7 +189,7 @@ def __init__( controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -336,8 +337,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -354,6 +355,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -414,6 +419,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -537,15 +546,27 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Check `image` + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + # Check `image` if isinstance(self.controlnet, ControlNetModel): self.check_image(image, prompt, prompt_embeds) elif isinstance(self.controlnet, MultiControlNetModel): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") - if len(image) != len(self.controlnet.nets): + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): raise ValueError( "For multiple controlnets: `image` must have the same length as the number of controlnets." ) @@ -556,12 +577,14 @@ def check_inputs( assert False # Check `controlnet_conditioning_scale` - if isinstance(self.controlnet, ControlNetModel): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif isinstance(self.controlnet, MultiControlNetModel): - if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): raise ValueError( @@ -675,7 +698,7 @@ def _default_height_width(self, height, width, image): if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): - height = image.shape[3] + height = image.shape[2] height = (height // 8) * 8 # round down to nearest multiple of 8 @@ -683,7 +706,7 @@ def _default_height_width(self, height, width, image): if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): - width = image.shape[2] + width = image.shape[3] width = (width // 8) * 8 # round down to nearest multiple of 8 @@ -755,8 +778,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -887,7 +910,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index b66cfe9b437e..54f00ebc23f2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor @@ -41,7 +42,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -54,7 +55,7 @@ def preprocess(image): return image -class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. @@ -182,8 +183,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -200,6 +201,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -260,6 +265,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -442,7 +451,7 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui if isinstance(image, PIL.Image.Image): image = [image] else: - image = [img for img in image] + image = list(image) if isinstance(image[0], PIL.Image.Image): width, height = image[0].size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index a7165457c67c..d543593fdbf5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -18,7 +18,7 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel @@ -53,7 +53,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ # TODO: feature_extractor is required to encode images (if they are in PIL format), @@ -67,7 +67,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -284,7 +284,7 @@ def __call__( The image or images to guide the image generation. If you provide a tensor, it needs to comply with the configuration of [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor` + `CLIPImageProcessor` height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -358,7 +358,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 835c88e19448..a0befdae73c4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -13,16 +13,17 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -78,7 +79,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -91,7 +92,7 @@ def preprocess(image): return image -class StableDiffusionImg2ImgPipeline(DiffusionPipeline): +class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. @@ -115,7 +116,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -128,7 +129,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -311,8 +312,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -329,6 +330,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -389,6 +394,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -577,6 +586,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -634,6 +644,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: @@ -695,7 +709,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -723,14 +742,15 @@ def __call__( image = latents has_nsfw_concept = None - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: - has_nsfw_concept = False + image = self.decode_latents(latents) + + if self.safety_checker is not None: + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + has_nsfw_concept = False - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cee7ace239db..8e0ea5a8d079 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -19,9 +19,10 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -137,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask): return mask, masked_image -class StableDiffusionInpaintPipeline(DiffusionPipeline): +class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. @@ -161,7 +162,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -174,7 +175,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -243,6 +244,14 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 9: + logger.warning( + f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," + f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`," + " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify" + " this behavior, please check whether you have loaded the right checkpoint." + ) self.register_modules( vae=vae, @@ -355,8 +364,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -373,6 +382,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -433,6 +446,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index cb953a7803b2..b7a0c942bbe2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -19,9 +19,10 @@ import PIL import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -42,7 +43,7 @@ def preprocess_image(image): w, h = image.size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -54,7 +55,7 @@ def preprocess_mask(mask, scale_factor=8): if not isinstance(mask, torch.FloatTensor): mask = mask.convert("L") w, h = mask.size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) @@ -76,12 +77,12 @@ def preprocess_mask(mask, scale_factor=8): # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape mask = mask.mean(dim=1, keepdim=True) h, w = mask.shape[-2:] - h, w = map(lambda x: x - x % 8, (h, w)) # resize to integer multiple of 8 + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) return mask -class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): +class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. @@ -105,7 +106,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["feature_extractor"] @@ -119,7 +120,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -299,8 +300,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -317,6 +318,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -377,6 +382,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -521,7 +530,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, strength: float = 0.8, @@ -611,10 +620,16 @@ def __call__( (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs - self.check_inputs(prompt, strength, callback_steps) + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 06ab580d492f..f7999a08dc9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -18,8 +18,9 @@ import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -47,7 +48,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -60,7 +61,7 @@ def preprocess(image): return image -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): +class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. @@ -84,7 +85,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -97,7 +98,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -493,8 +494,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_ prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -511,6 +512,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -571,6 +576,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 2d40390b41d1..99aca66db809 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -17,7 +17,9 @@ import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser +from k_diffusion.sampling import get_sigmas_karras +from ...loaders import TextualInversionLoaderMixin from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -41,7 +43,7 @@ def apply_model(self, *args, **kwargs): return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample -class StableDiffusionKDiffusionPipeline(DiffusionPipeline): +class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -71,7 +73,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -111,7 +113,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) model = ModelWrapper(unet, scheduler.alphas_cumprod) - if scheduler.prediction_type == "v_prediction": + if scheduler.config.prediction_type == "v_prediction": self.k_diffusion_model = CompVisVDenoiser(model) else: self.k_diffusion_model = CompVisDenoiser(model) @@ -220,8 +222,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -238,6 +240,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -298,6 +304,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -354,10 +364,17 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -369,6 +386,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: @@ -400,6 +443,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + use_karras_sigmas: Optional[bool] = False, ): r""" Function invoked when calling the pipeline for generation. @@ -456,7 +500,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to + `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M + Karras`. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. @@ -469,10 +516,18 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -494,11 +549,19 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) - sigmas = self.scheduler.sigmas + + # 5. Prepare sigmas + if use_karras_sigmas: + sigma_min: float = self.k_diffusion_model.sigmas[0].item() + sigma_max: float = self.k_diffusion_model.sigmas[-1].item() + sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) + sigmas = sigmas.to(device) + else: + sigmas = self.scheduler.sigmas sigmas = sigmas.to(prompt_embeds.dtype) - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -513,7 +576,7 @@ def __call__( self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) - # 6. Define model function + # 7. Define model function def model_fn(x, t): latent_model_input = torch.cat([x] * 2) t = torch.cat([t] * 2) @@ -524,16 +587,16 @@ def model_fn(x, t): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) return noise_pred - # 7. Run k-diffusion solver + # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) - # 8. Post-processing + # 9. Post-processing image = self.decode_latents(latents) - # 9. Run safety checker + # 10. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 624d0e625828..822bd49ce31c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -38,7 +38,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py new file mode 100644 index 000000000000..b7ded03d529b --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -0,0 +1,796 @@ +# Copyright 2023 TIME Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...loaders import TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...schedulers.scheduling_utils import SchedulerMixin +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionModelEditingPipeline + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + + >>> pipe = pipe.to("cuda") + + >>> source_prompt = "A pack of roses" + >>> destination_prompt = "A pack of blue roses" + >>> pipe.edit_model(source_prompt, destination_prompt) + + >>> prompt = "A field of roses" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models". + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + with_to_k ([`bool`]): + Whether to edit the key projection matrices along wiht the value projection matrices. + with_augs ([`list`]): + Textual augmentations to apply while editing the text-to-image model. Set to [] for no augmentations. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + with_to_k: bool = True, + with_augs: list = AUGS_CONST, + ): + super().__init__() + + if isinstance(scheduler, PNDMScheduler): + logger.error("PNDMScheduler for this pipeline is currently not supported.") + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.with_to_k = with_to_k + self.with_augs = with_augs + + # get cross-attention layers + ca_layers = [] + + def append_ca(net_): + if net_.__class__.__name__ == "CrossAttention": + ca_layers.append(net_) + elif hasattr(net_, "children"): + for net__ in net_.children(): + append_ca(net__) + + # recursively find all cross-attention layers in unet + for net in self.unet.named_children(): + if "down" in net[0]: + append_ca(net[1]) + elif "up" in net[0]: + append_ca(net[1]) + elif "mid" in net[0]: + append_ca(net[1]) + + # get projection matrices + self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] + self.projection_matrices = [l.to_v for l in self.ca_clip_layers] + self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] + if self.with_to_k: + self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] + self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def edit_model( + self, + source_prompt: str, + destination_prompt: str, + lamb: float = 0.1, + restart_params: bool = True, + ): + r""" + Apply model editing via closed-form solution (see Eq. 5 in the TIME paper https://arxiv.org/abs/2303.08084) + + Args: + source_prompt (`str`): + The source prompt containing the concept to be edited. + destination_prompt (`str`): + The destination prompt. Must contain all words from source_prompt with additional ones to specify the + target edit. + lamb (`float`, *optional*, defaults to 0.1): + The lambda parameter specifying the regularization intesity. Smaller values increase the editing power. + restart_params (`bool`, *optional*, defaults to True): + Restart the model parameters to their pre-trained version before editing. This is done to avoid edit + compounding. When it is False, edits accumulate. + """ + + # restart LDM parameters + if restart_params: + num_ca_clip_layers = len(self.ca_clip_layers) + for idx_, l in enumerate(self.ca_clip_layers): + l.to_v = copy.deepcopy(self.og_matrices[idx_]) + self.projection_matrices[idx_] = l.to_v + if self.with_to_k: + l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) + self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k + + # set up sentences + old_texts = [source_prompt] + new_texts = [destination_prompt] + # add augmentations + base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] + for aug in self.with_augs: + old_texts.append(aug + base) + base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] + for aug in self.with_augs: + new_texts.append(aug + base) + + # prepare input k* and v* + old_embs, new_embs = [], [] + for old_text, new_text in zip(old_texts, new_texts): + text_input = self.tokenizer( + [old_text, new_text], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + old_emb, new_emb = text_embeddings + old_embs.append(old_emb) + new_embs.append(new_emb) + + # identify corresponding destinations for each token in old_emb + idxs_replaces = [] + for old_text, new_text in zip(old_texts, new_texts): + tokens_a = self.tokenizer(old_text).input_ids + tokens_b = self.tokenizer(new_text).input_ids + tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] + tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] + num_orig_tokens = len(tokens_a) + idxs_replace = [] + j = 0 + for i in range(num_orig_tokens): + curr_token = tokens_a[i] + while tokens_b[j] != curr_token: + j += 1 + idxs_replace.append(j) + j += 1 + while j < 77: + idxs_replace.append(j) + j += 1 + while len(idxs_replace) < 77: + idxs_replace.append(76) + idxs_replaces.append(idxs_replace) + + # prepare batch: for each pair of setences, old context and new values + contexts, valuess = [], [] + for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): + context = old_emb.detach() + values = [] + with torch.no_grad(): + for layer in self.projection_matrices: + values.append(layer(new_emb[idxs_replace]).detach()) + contexts.append(context) + valuess.append(values) + + # edit the model + for layer_num in range(len(self.projection_matrices)): + # mat1 = \lambda W + \sum{v k^T} + mat1 = lamb * self.projection_matrices[layer_num].weight + + # mat2 = \lambda I + \sum{k k^T} + mat2 = lamb * torch.eye( + self.projection_matrices[layer_num].weight.shape[1], + device=self.projection_matrices[layer_num].weight.device, + ) + + # aggregate sums for mat1, mat2 + for context, values in zip(contexts, valuess): + context_vector = context.reshape(context.shape[0], context.shape[1], 1) + context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) + value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) + for_mat1 = (value_vector @ context_vector_T).sum(dim=0) + for_mat2 = (context_vector @ context_vector_T).sum(dim=0) + mat1 += for_mat1 + mat2 += for_mat2 + + # update projection matrix + self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 3fea4c2d83bb..392b2a72a76f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -15,8 +15,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, PNDMScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring @@ -47,7 +48,7 @@ """ -class StableDiffusionPanoramaPipeline(DiffusionPipeline): +class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation". @@ -75,7 +76,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -88,7 +89,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -212,8 +213,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -230,6 +231,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -290,6 +295,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -491,8 +500,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -577,7 +586,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -616,7 +625,9 @@ def __call__( latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 9c928129d0b9..0239c8128171 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -23,11 +23,12 @@ from transformers import ( BlipForConditionalGeneration, BlipProcessor, - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, ) +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler @@ -50,7 +51,7 @@ @dataclass -class Pix2PixInversionPipelineOutput(BaseOutput): +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): """ Output class for Stable Diffusion pipelines. @@ -180,7 +181,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -242,8 +243,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -297,7 +298,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the @@ -318,7 +319,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, safety_checker: StableDiffusionSafetyChecker, inverse_scheduler: DDIMInverseScheduler, caption_generator: BlipForConditionalGeneration, @@ -452,8 +453,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -470,6 +471,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -530,6 +535,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -828,8 +837,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -920,7 +929,7 @@ def __call__( # 5. Generate the inverted noise from the input image or any other image # generated from the input prompt. - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index b24354a8e568..ebac58e18f62 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -17,8 +17,9 @@ import torch import torch.nn.functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring @@ -64,8 +65,8 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -87,7 +88,7 @@ def __call__( # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input -class StableDiffusionSAGPipeline(DiffusionPipeline): +class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -111,7 +112,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -124,7 +125,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -229,8 +230,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -247,6 +248,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -307,6 +312,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -496,8 +505,8 @@ def __call__( https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -565,7 +574,7 @@ def __call__( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # and `sag_scale` is` `s` of equation (15) + # and `sag_scale` is` `s` of equation (16) # of the self-attentnion guidance paper: https://arxiv.org/pdf/2210.00939.pdf # `sag_scale = 0` means no self-attention guidance do_self_attention_guidance = sag_scale > 0.0 @@ -586,7 +595,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -636,7 +645,7 @@ def get_map_size(module, input, output): # perform self-attention guidance with the stored self-attentnion map if do_self_attention_guidance: # classifier-free guidance produces two chunks of attention map - # and we only use unconditional one according to equation (24) + # and we only use unconditional one according to equation (25) # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 @@ -692,7 +701,7 @@ def sag_masking(self, original_latents, attn_map, map_size, t, eps): # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf bh, hw1, hw2 = attn_map.shape b, latent_channel, latent_h, latent_w = original_latents.shape - h = self.unet.attention_head_dim + h = self.unet.config.attention_head_dim if isinstance(h, list): h = h[-1] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 9f8f44a12bb4..c0086b32d6fd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor @@ -37,7 +38,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -50,7 +51,7 @@ def preprocess(image): return image -class StableDiffusionUpscalePipeline(DiffusionPipeline): +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-guided image super-resolution using Stable Diffusion 2. @@ -176,8 +177,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -194,6 +195,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -254,6 +259,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -317,10 +326,50 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - def check_inputs(self, prompt, image, noise_level, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) @@ -480,13 +529,27 @@ def __call__( """ # 1. Check inputs - self.check_inputs(prompt, image, noise_level, callback_steps) + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index a8ba0b504628..fafb8d1d2800 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -19,10 +19,11 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -47,7 +48,7 @@ """ -class StableUnCLIPPipeline(DiffusionPipeline): +class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin): """ Pipeline for text-to-image generation using stable unCLIP. @@ -178,6 +179,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -324,8 +350,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -342,6 +368,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -402,6 +432,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -581,6 +615,7 @@ def noise_image_embeddings( noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + self.image_normalizer.to(image_embeds.device) image_embeds = self.image_normalizer.scale(image_embeds) image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) @@ -650,8 +685,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -842,7 +877,7 @@ def __call__( timesteps = self.scheduler.timesteps # 11. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) latents = self.prepare_latents( shape=shape, @@ -884,6 +919,10 @@ def __call__( # 14. Post-processing image = self.decode_latents(latents) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + # 15. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 99caa8be65a5..22b7280f3679 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -17,14 +17,15 @@ import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.utils.import_utils import is_accelerate_available +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, randn_tensor, replace_example_docstring +from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -60,7 +61,7 @@ """ -class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): """ Pipeline for text-guided image to image generation using stable unCLIP. @@ -68,7 +69,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Feature extractor for image pre-processing before being encoded. image_encoder ([`CLIPVisionModelWithProjection`]): CLIP vision model for encoding images. @@ -91,7 +92,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): """ # image encoding components - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection # image noising components @@ -109,7 +110,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): def __init__( self, # image encoding components - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, # image noising components image_normalizer: StableUnCLIPImageNormalizer, @@ -180,6 +181,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -224,8 +250,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -242,6 +268,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -302,6 +332,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -363,7 +397,7 @@ def _encode_image( # what the expected dimensions of inputs should be and how we handle the encoding. repeat_by = num_images_per_prompt - if not image_embeds: + if image_embeds is None: if not isinstance(image, torch.Tensor): image = self.feature_extractor(images=image, return_tensors="pt").pixel_values @@ -548,6 +582,7 @@ def noise_image_embeddings( noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + self.image_normalizer.to(image_embeds.device) image_embeds = self.image_normalizer.scale(image_embeds) image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) @@ -571,8 +606,8 @@ def noise_image_embeddings( @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 20, @@ -597,8 +632,8 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. + The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be + used or prompt is initialized to `""`. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the @@ -619,8 +654,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -674,6 +709,9 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + if prompt is None and prompt_embeds is None: + prompt = len(image) * [""] if isinstance(image, list) else "" + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -734,7 +772,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size=batch_size, num_channels_latents=num_channels_latents, @@ -777,6 +815,10 @@ def __call__( # 9. Post-processing image = self.decode_latents(latents) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + # 10. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py index 9c7f190d0505..7362df7e80e7 100644 --- a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +++ b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union + import torch from torch import nn @@ -37,6 +39,15 @@ def __init__( self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) self.std = nn.Parameter(torch.ones(1, embedding_dim)) + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + ): + self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) + self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) + return self + def scale(self, embeds): embeds = (embeds - self.mean) * 1.0 / self.std return embeds diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 3d0ddce7157e..87e7b3e6c9eb 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -5,7 +5,7 @@ import numpy as np import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel @@ -45,7 +45,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -59,7 +59,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -623,7 +623,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 4535500e2592..2e0ab15eb975 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -120,7 +120,7 @@ def __call__( sample = (sample / 2 + 0.5).clamp(0, 1) image = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": - image = self.numpy_to_pil(sample) + image = self.numpy_to_pil(image) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index c2437857a23a..165a1a0f0d98 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -29,3 +29,4 @@ class TextToVideoSDPipelineOutput(BaseOutput): from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401 + from .pipeline_text_to_video_zero import TextToVideoZeroPipeline diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 453809ef6df7..6fc89e945604 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -72,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - return images -class TextToVideoSDPipeline(DiffusionPipeline): +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-video generation. @@ -238,8 +239,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -256,6 +257,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -316,6 +321,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -513,8 +522,8 @@ def __call__( usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -597,7 +606,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py new file mode 100644 index 000000000000..cf5e6e399a77 --- /dev/null +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -0,0 +1,541 @@ +import copy +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import BaseOutput + + +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be + equal to 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Sparse Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +@dataclass +class TextToVideoPipelineOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +class TextToVideoZeroPipeline(StableDiffusionPipeline): + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion. + + This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods + the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2)) + + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform ddpm forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: latent code at time t0 + t0: t0 + t1: t1 + generator: torch.Generator object + + Returns: + x_t1: forward process applied to x_t0 from time t0 to t1. + """ + eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + cross_attention_kwargs=None, + ): + """ + Perform backward process given list of time steps + + Args: + latents: Latents at time timesteps[0]. + timesteps: time steps, along which to perform backward process. + prompt_embeds: Pre-generated text embeddings + guidance_scale: + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + extra_step_kwargs: extra_step_kwargs. + cross_attention_kwargs: cross_attention_kwargs. + num_warmup_steps: number of warmup steps. + + Returns: + latents: latents of backward process output at time timesteps[-1] + """ + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + return latents.clone().detach() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int] = 8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + t0: int = 44, + t1: int = 47, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + video_length (`int`, *optional*, defaults to 8): The number of generated video frames + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"numpy"`): + The output format of the generated image. Choose between `"latent"` and `"numpy"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439), + Sect. 3.3.1. + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]: + The output contains a ndarray of the generated images, when output_type != 'latent', otherwise a latent + codes of generated image, and a list of `bool`s denoting whether the corresponding generated image + likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + assert video_length > 0 + frame_ids = list(range(video_length)) + + assert num_videos_per_prompt == 1 + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + ) + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].item(), + t1=timesteps[-t1 - 1].item(), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + x_1k_t1 = torch.cat([x_1_t1, x_2k_t1]) + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + self.scheduler = scheduler_copy + x_1k_0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 :], + prompt_embeds=prompt_embeds, + latents=x_1k_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + latents = x_1k_0 + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + image = self.decode_latents(latents) + # Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index e5e766846841..56d522354d9a 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -19,7 +19,7 @@ import torch from torch.nn import functional as F from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, @@ -48,7 +48,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `image_encoder`. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of @@ -73,7 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): text_proj: UnCLIPTextProjModel text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer - feature_extractor: CLIPFeatureExtractor + feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection super_res_first: UNet2DModel super_res_last: UNet2DModel @@ -87,7 +87,7 @@ def __init__( text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_proj: UnCLIPTextProjModel, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, super_res_first: UNet2DModel, super_res_last: UNet2DModel, @@ -264,7 +264,7 @@ def __call__( The image or images to guide the image generation. If you provide a tensor, it needs to comply with the configuration of [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed. + `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. decoder_num_inference_steps (`int`, *optional*, defaults to 25): diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py index 0a54c3319f28..0414559500c1 100644 --- a/src/diffusers/pipelines/unclip/text_proj.py +++ b/src/diffusers/pipelines/unclip/text_proj.py @@ -77,10 +77,10 @@ def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1) text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) - text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1) return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index dd5410dbc0b0..35ddfcadc3cb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -3,16 +3,22 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.attention import Attention -from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor +from ...models.attention_processor import ( + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging +from ...utils import deprecate, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -167,18 +173,24 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, it will skip the normalization and activation layers in post-processing norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + encoder_hid_dim (`int`, *optional*, defaults to None): + If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, or `"projection"`. + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_act_fn (`str`, *optional*, default to `None`): + Optional activation function to use on the time embeddings only one time before they as passed to the rest + of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str, *optional*, default to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, default to `None`): @@ -187,6 +199,13 @@ class conditioning with `class_embed_type` equal to `None`. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the + `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will + default to `False`. """ _supports_gradient_checkpointing = True @@ -215,13 +234,14 @@ def __init__( ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + encoder_hid_dim: Optional[int] = None, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, @@ -229,12 +249,18 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, ): super().__init__() @@ -265,6 +291,18 @@ def __init__( f" {attention_head_dim}. `down_block_types`: {down_block_types}." ) + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:" + f" {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" + f" {layers_per_block}. `down_block_types`: {down_block_types}." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( @@ -298,6 +336,11 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) + if encoder_hid_dim is not None: + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + else: + self.encoder_hid_proj = None + # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) @@ -318,18 +361,57 @@ def __init__( # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None + if time_embedding_act_fn is None: + self.time_embed_act = None + elif time_embedding_act_fn == "swish": + self.time_embed_act = lambda x: F.silu(x) + elif time_embedding_act_fn == "mish": + self.time_embed_act = nn.Mish() + elif time_embedding_act_fn == "silu": + self.time_embed_act = nn.SiLU() + elif time_embedding_act_fn == "gelu": + self.time_embed_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + only_cross_attention = [only_cross_attention] * len(down_block_types) + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -339,15 +421,15 @@ def __init__( down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], in_channels=input_channel, out_channels=output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[i], attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, @@ -355,6 +437,9 @@ def __init__( only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.down_blocks.append(down_block) @@ -362,12 +447,12 @@ def __init__( if mid_block_type == "UNetMidBlockFlatCrossAttn": self.mid_block = UNetMidBlockFlatCrossAttn( in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, @@ -377,14 +462,17 @@ def __init__( elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": self.mid_block = UNetMidBlockFlatSimpleCrossAttn( in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None @@ -397,6 +485,8 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -416,22 +506,25 @@ def __init__( up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=reversed_cross_attention_dim[i], attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -451,6 +544,19 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + @property + def in_channels(self): + deprecate( + "in_channels", + "1.0.0", + ( + "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use" + " `unet.config.in_channels` instead" + ), + standard_warn=False, + ) + return self.config.in_channels + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" @@ -505,6 +611,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -661,7 +773,17 @@ def forward( class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None: + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) # 2. pre-process sample = self.conv_in(sample) @@ -1396,6 +1518,9 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() @@ -1419,11 +1544,16 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ] attentions = [] for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + attentions.append( Attention( query_dim=in_channels, @@ -1434,7 +1564,9 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=AttnAddedKVProcessor(), + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, ) ) resnets.append( @@ -1449,6 +1581,7 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index f482ef11940a..6d6b5e7863eb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -3,7 +3,7 @@ import PIL.Image import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -41,12 +41,12 @@ class VersatileDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionMegaSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor + image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModel image_encoder: CLIPVisionModel image_unet: UNet2DConditionModel @@ -57,7 +57,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): def __init__( self, tokenizer: CLIPTokenizer, - image_feature_extractor: CLIPFeatureExtractor, + image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModel, image_unet: UNet2DConditionModel, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 529d9a2ae9c0..0f385ed6612c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -20,7 +20,7 @@ import torch import torch.utils.checkpoint from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, @@ -55,7 +55,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor + image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel @@ -68,7 +68,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): def __init__( self, tokenizer: CLIPTokenizer, - image_feature_extractor: CLIPFeatureExtractor, + image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModelWithProjection, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index fd6855af3852..2b47184d7773 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -19,7 +19,7 @@ import PIL import torch import torch.utils.checkpoint -from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -48,7 +48,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ - image_feature_extractor: CLIPFeatureExtractor + image_feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL @@ -56,7 +56,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): def __init__( self, - image_feature_extractor: CLIPFeatureExtractor, + image_feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, @@ -134,7 +134,7 @@ def normalize_embeddings(encoder_output): return embeds if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: - prompt = [p for p in prompt] + prompt = list(prompt) batch_size = len(prompt) if isinstance(prompt, list) else 1 diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index d1bb754c7b58..fdca625fd99d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -17,7 +17,7 @@ import torch import torch.utils.checkpoint -from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -48,7 +48,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor + image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 29a79d391e55..6b62d8893482 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -201,15 +201,38 @@ def _get_variance(self, timestep, prev_timestep): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -315,14 +338,13 @@ def step( ) # 4. Clip or threshold "predicted x_0" - if self.config.clip_sample: + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) - if self.config.thresholding: - pred_original_sample = self._threshold_sample(pred_original_sample) - # 5. compute variance: "sigma_t(Ξ·)" -> see formula (16) # Οƒ_t = sqrt((1 βˆ’ Ξ±_tβˆ’1)/(1 βˆ’ Ξ±_t)) * sqrt(1 βˆ’ Ξ±_t/Ξ±_tβˆ’1) variance = self._get_variance(timestep, prev_timestep) @@ -358,6 +380,7 @@ def step( return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -365,15 +388,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) @@ -381,19 +404,20 @@ def add_noise( noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity( self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 206294066cb3..eaaf497f9c1d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor +from ..utils import BaseOutput, deprecate, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @@ -167,6 +167,16 @@ def __init__( self.variance_type = variance_type + @property + def num_train_timesteps(self): + deprecate( + "num_train_timesteps", + "1.0.0", + "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`", + standard_warn=False, + ) + return self.config.num_train_timesteps + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -215,15 +225,18 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t + # we always take the log of variance, so clamp it to ensure it's not 0 + variance = torch.clamp(variance, min=1e-20) + if variance_type is None: variance_type = self.config.variance_type # hacks - were probably added for training stability if variance_type == "fixed_small": - variance = torch.clamp(variance, min=1e-20) + variance = variance # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": - variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.log(variance) variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = current_beta_t @@ -234,22 +247,45 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): return predicted_variance elif variance_type == "learned_range": min_log = torch.log(variance) - max_log = torch.log(self.betas[t]) + max_log = torch.log(current_beta_t) frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log return variance def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample def step( self, @@ -309,14 +345,13 @@ def step( ) # 3. Clip or threshold "predicted x_0" - if self.config.clip_sample: + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) - if self.config.thresholding: - pred_original_sample = self._threshold_sample(pred_original_sample) - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t @@ -355,15 +390,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) @@ -375,15 +410,15 @@ def get_velocity( self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 39f8f17df5d3..8ea001a882d0 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -181,14 +181,22 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order @@ -196,15 +204,38 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor @@ -236,11 +267,7 @@ def convert_model_output( ) if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - x0_pred = self._threshold_sample(x0_pred).type(orig_dtype) + x0_pred = self._threshold_sample(x0_pred) if self.config.algorithm_type == "deis": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] @@ -458,6 +485,7 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -465,15 +493,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 474d9b0d7339..3399ee2c54cb 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. - + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {Οƒi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -136,6 +139,7 @@ def __init__( algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -181,6 +185,7 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 + self.use_karras_sigmas = use_karras_sigmas def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ @@ -192,14 +197,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order @@ -207,15 +227,76 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor @@ -256,11 +337,8 @@ def convert_model_output( ) if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - x0_pred = self._threshold_sample(x0_pred).type(orig_dtype) + x0_pred = self._threshold_sample(x0_pred) + return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": @@ -507,6 +585,7 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -514,15 +593,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index a02171a2df91..049e2b1dbd4d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -190,8 +190,8 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: the number of diffusion steps used when generating samples with a pre-trained model. """ steps = num_inference_steps - order = self.solver_order - if self.lower_order_final: + order = self.config.solver_order + if self.config.lower_order_final: if order == 3: if steps % 3 == 0: orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] @@ -227,7 +227,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) @@ -239,15 +239,38 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor @@ -288,11 +311,8 @@ def convert_model_output( ) if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - x0_pred = self._threshold_sample(x0_pred).type(orig_dtype) + x0_pred = self._threshold_sample(x0_pred) + return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": @@ -582,6 +602,7 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -589,15 +610,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 1b517bdec570..6b08e9bfc207 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -279,6 +279,7 @@ def step( prev_sample=prev_sample, pred_original_sample=pred_original_sample ) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -286,19 +287,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - schedule_timesteps = self.timesteps step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index d6252904fd9a..7237128cbf07 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): interpolation_type (`str`, default `"linear"`, optional): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {Οƒi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -118,6 +122,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -149,6 +154,7 @@ def __init__( timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + self.use_karras_sigmas = use_karras_sigmas def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] @@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) @@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic " 'linear' or 'log_linear'" ) + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): @@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic else: self.timesteps = torch.from_numpy(timesteps).to(device=device) + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: torch.FloatTensor, @@ -312,19 +360,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - schedule_timesteps = self.timesteps step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index f7f1467fc53a..c1fd7b4967bc 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -112,8 +112,12 @@ def __init__( # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + if self.state_in_first_order: pos = -1 else: @@ -277,18 +281,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t) for t in timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index c8b1f2c3bedf..2fa0431e1292 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -114,8 +114,13 @@ def __init__( # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + if self.state_in_first_order: pos = -1 else: @@ -201,7 +206,7 @@ def set_timesteps( else: timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -323,6 +328,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -330,18 +336,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t) for t in timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 809da798f889..bb80c4a54bfe 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -113,8 +113,13 @@ def __init__( # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + if self.state_in_first_order: pos = -1 else: @@ -190,7 +195,7 @@ def set_timesteps( timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -304,6 +309,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -311,18 +317,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t) for t in timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0fe1f77f9b5c..68a8e1bddc01 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -284,6 +284,7 @@ def step( return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 562cefb17893..01c02a21bbfc 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -398,22 +398,23 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): return prev_sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, - ) -> torch.Tensor: + ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index e4f38d0f5dad..2cce68f7d962 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -194,33 +194,64 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 self.last_sample = None if self.solver_p: - self.solver_p.set_timesteps(num_inference_steps, device=device) + self.solver_p.set_timesteps(self.num_inference_steps, device=device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = ( - sample.flatten(1) - .abs() - .quantile(self.config.dynamic_thresholding_ratio, dim=1) - .clamp_min(self.config.sample_max_value) - .view(-1, *([1] * (sample.ndim - 1))) - ) - return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor @@ -253,11 +284,8 @@ def convert_model_output( ) if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - x0_pred = self._threshold_sample(x0_pred).type(orig_dtype) + x0_pred = self._threshold_sample(x0_pred) + return x0_pred else: if self.config.prediction_type == "epsilon": @@ -584,6 +612,7 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -591,15 +620,15 @@ def add_noise( timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d803b053be71..c717d722f84c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) from .deprecation_utils import deprecate @@ -37,6 +38,8 @@ from .dynamic_modules_utils import get_class_from_dynamic_module from .hub_utils import ( HF_HUB_OFFLINE, + _add_variant, + _get_model_file, extract_commit_hash, http_user_agent, ) @@ -55,6 +58,7 @@ is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, + is_note_seq_available, is_omegaconf_available, is_onnx_available, is_safetensors_available, @@ -73,7 +77,7 @@ from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION -from .torch_utils import randn_tensor +from .torch_utils import is_compiled_module, randn_tensor if is_torch_available(): @@ -82,9 +86,11 @@ load_hf_numpy, load_image, load_numpy, + load_pt, nightly, parse_flag_from_env, print_tensor_test, + require_torch_2, require_torch_gpu, skip_mps, slow, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index b9e60a2a873b..1134ba6fb656 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,3 +30,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py index 5db4c7d58d1e..162bac1c4331 100644 --- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 7772c1a06b49..2bb80d136f33 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FlaxControlNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/diffusers/utils/dummy_note_seq_objects.py b/src/diffusers/utils/dummy_note_seq_objects.py new file mode 100644 index 000000000000..c02d0b015aed --- /dev/null +++ b/src/diffusers/utils/dummy_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class MidiProcessor(metaclass=DummyObject): + _backends = ["note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["note_seq"]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 700a3080fa11..014e193aa32a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class T5FilmDecoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5a28ce8cb04e..8a521457f2e3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class TextualInversionLoaderMixin(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -32,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AudioLDMPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -227,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionModelEditingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionPanoramaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -362,6 +407,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class TextToVideoZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py b/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py new file mode 100644 index 000000000000..fbde04e33f0a --- /dev/null +++ b/src/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class SpectrogramDiffusionPipeline(metaclass=DummyObject): + _backends = ["transformers", "torch", "note_seq"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["transformers", "torch", "note_seq"]) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 916b18d35e7e..511763ec6687 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -18,16 +18,30 @@ import re import sys import traceback +import warnings from pathlib import Path from typing import Dict, Optional, Union from uuid import uuid4 -from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami +from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami from huggingface_hub.file_download import REGEX_COMMIT_HASH -from huggingface_hub.utils import is_jinja_available +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + is_jinja_available, +) +from packaging import version +from requests import HTTPError from .. import __version__ -from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT +from .constants import ( + DEPRECATED_REVISION_ARGS, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, +) from .import_utils import ( ENV_VARS_TRUE_VALUES, _flax_version, @@ -215,3 +229,130 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure " "the directory exists and can be written to." ) + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, + commit_hash=None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + # 1. First check if deprecated way of loading from branches is used + if ( + revision in DEPRECATED_REVISION_ARGS + and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) + and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0") + ): + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=_add_variant(weights_name, revision), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + warnings.warn( + f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + return model_file + except: # noqa: E722 + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", + FutureWarning, + ) + try: + # 2. Load model file as usual + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision or commit_hash, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 3c09cb24f965..fd7538b1b5e9 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -153,9 +153,12 @@ candidates = ( "onnxruntime", "onnxruntime-gpu", + "ort_nightly_gpu", "onnxruntime-directml", "onnxruntime-openvino", "ort_nightly_directml", + "onnxruntime-rocm", + "onnxruntime-training", ) _onnxruntime_version = None # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu @@ -172,9 +175,22 @@ # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. # _opencv_available = importlib.util.find_spec("opencv-python") is not None try: - _opencv_version = importlib_metadata.version("opencv-python") - _opencv_available = True - logger.debug(f"Successfully imported cv2 version {_opencv_version}") + candidates = ( + "opencv-python", + "opencv-contrib-python", + "opencv-python-headless", + "opencv-contrib-python-headless", + ) + _opencv_version = None + for pkg in candidates: + try: + _opencv_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _opencv_available = _opencv_version is not None + if _opencv_available: + logger.debug(f"Successfully imported cv2 version {_opencv_version}") except importlib_metadata.PackageNotFoundError: _opencv_available = False @@ -218,6 +234,13 @@ except importlib_metadata.PackageNotFoundError: _k_diffusion_available = False +_note_seq_available = importlib.util.find_spec("note_seq") is not None +try: + _note_seq_version = importlib_metadata.version("note_seq") + logger.debug(f"Successfully imported note-seq version {_note_seq_version}") +except importlib_metadata.PackageNotFoundError: + _note_seq_available = False + _wandb_available = importlib.util.find_spec("wandb") is not None try: _wandb_version = importlib_metadata.version("wandb") @@ -304,6 +327,10 @@ def is_k_diffusion_available(): return _k_diffusion_available +def is_note_seq_available(): + return _note_seq_available + + def is_wandb_available(): return _wandb_available @@ -380,6 +407,12 @@ def is_compel_available(): install k-diffusion` """ +# docstyle-ignore +NOTE_SEQ_IMPORT_ERROR = """ +{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip +install note-seq` +""" + # docstyle-ignore WANDB_IMPORT_ERROR = """ {0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip @@ -416,6 +449,7 @@ def is_compel_available(): ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index f91a49b7a8a7..b6e8a219e129 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -84,7 +84,7 @@ def update(self, *args, **kwargs): def __getitem__(self, k): if isinstance(k, str): - inner_dict = {k: v for (k, v) in self.items()} + inner_dict = dict(self.items()) return inner_dict[k] else: return self.to_tuple()[k] diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a3b8029f828..afea0540b765 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -21,9 +21,11 @@ BACKENDS_MAPPING, is_compel_available, is_flax_available, + is_note_seq_available, is_onnx_available, is_opencv_available, is_torch_available, + is_torch_version, ) from .logging import get_logger @@ -164,6 +166,15 @@ def require_torch(test_case): return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) +def require_torch_2(test_case): + """ + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. + """ + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( + test_case + ) + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( @@ -198,6 +209,13 @@ def require_onnxruntime(test_case): return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) +def require_note_seq(test_case): + """ + Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed. + """ + return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) + + def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: if isinstance(arry, str): # local_path = "/home/patrick_huggingface_co/" diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 113e64c16bac..b9815cbceede 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available +from .import_utils import is_torch_available, is_torch_version if is_torch_available(): @@ -68,3 +68,10 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents + + +def is_compiled_module(module): + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 9119ae30f42f..0bb10c3d5185 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -73,7 +73,7 @@ def __call__( # Sample gaussian noise to begin loop image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) image = image.to(self.device) diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py index a8af08d3980a..494c5a1a4e95 100644 --- a/tests/fixtures/custom_pipeline/what_ever.py +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -73,7 +73,7 @@ def __call__( # Sample gaussian noise to begin loop image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) image = image.to(self.device) diff --git a/tests/fixtures/elise_format0.mid b/tests/fixtures/elise_format0.mid new file mode 100644 index 000000000000..33dbabe7ab1d Binary files /dev/null and b/tests/fixtures/elise_format0.mid differ diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py new file mode 100644 index 000000000000..172d6d4d91fc --- /dev/null +++ b/tests/models/test_attention_processor.py @@ -0,0 +1,75 @@ +import unittest + +import torch + +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor + + +class AttnAddedKVProcessorTests(unittest.TestCase): + def get_constructor_arguments(self, only_cross_attention: bool = False): + query_dim = 10 + + if only_cross_attention: + cross_attention_dim = 12 + else: + # when only cross attention is not set, the cross attention dim must be the same as the query dim + cross_attention_dim = query_dim + + return { + "query_dim": query_dim, + "cross_attention_dim": cross_attention_dim, + "heads": 2, + "dim_head": 4, + "added_kv_proj_dim": 6, + "norm_num_groups": 1, + "only_cross_attention": only_cross_attention, + "processor": AttnAddedKVProcessor(), + } + + def get_forward_arguments(self, query_dim, added_kv_proj_dim): + batch_size = 2 + + hidden_states = torch.rand(batch_size, query_dim, 3, 2) + encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim) + attention_mask = None + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + } + + def test_only_cross_attention(self): + # self and cross attention + + torch.manual_seed(0) + + constructor_args = self.get_constructor_arguments(only_cross_attention=False) + attn = Attention(**constructor_args) + + self.assertTrue(attn.to_k is not None) + self.assertTrue(attn.to_v is not None) + + forward_args = self.get_forward_arguments( + query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"] + ) + + self_and_cross_attn_out = attn(**forward_args) + + # only self attention + + torch.manual_seed(0) + + constructor_args = self.get_constructor_arguments(only_cross_attention=True) + attn = Attention(**constructor_args) + + self.assertTrue(attn.to_k is None) + self.assertTrue(attn.to_v is None) + + forward_args = self.get_forward_arguments( + query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"] + ) + + only_cross_attn_out = attn(**forward_args) + + self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all()) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index b814f5f88a30..d3a3d5cfc9a0 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -116,7 +116,7 @@ def test_output_pretrained(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - num_features = model.in_channels + num_features = model.config.in_channels seq_len = 16 noise = torch.randn((1, seq_len, num_features)).permute( 0, 2, 1 @@ -264,7 +264,7 @@ def test_output_pretrained(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - num_features = value_function.in_channels + num_features = value_function.config.in_channels seq_len = 14 noise = torch.randn((1, seq_len, num_features)).permute( 0, 2, 1 diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index ab6f12085e0f..17e08e0a426e 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -22,7 +22,7 @@ from parameterized import parameterized from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.utils import ( floats_tensor, load_hf_numpy, @@ -41,7 +41,7 @@ torch.backends.cuda.matmul.allow_tf32 = False -def create_lora_layers(model): +def create_lora_layers(model, mock_weights: bool = True): lora_attn_procs = {} for name in model.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim @@ -57,12 +57,13 @@ def create_lora_layers(model): lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 return lora_attn_procs @@ -199,6 +200,74 @@ def test_model_with_use_linear_projection(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_model_with_cross_attention_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["cross_attention_dim"] = (32, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_simple_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + batch_size, _, _, sample_size = inputs_dict["sample"].shape + + init_dict["class_embed_type"] = "simple_projection" + init_dict["projection_class_embeddings_input_dim"] = sample_size + + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_class_embeddings_concat(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + batch_size, _, _, sample_size = inputs_dict["sample"].shape + + init_dict["class_embed_type"] = "simple_projection" + init_dict["projection_class_embeddings_input_dim"] = sample_size + init_dict["class_embeddings_concat"] = True + + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_model_attention_slicing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -310,26 +379,7 @@ def test_lora_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + lora_attn_procs = create_lora_layers(model) # make sure we can set a list of attention processors model.set_attn_processor(lora_attn_procs) @@ -397,28 +447,7 @@ def test_lora_save_load_safetensors(self): with torch.no_grad(): old_sample = model(**inputs_dict).sample - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 - + lora_attn_procs = create_lora_layers(model) model.set_attn_processor(lora_attn_procs) with torch.no_grad(): @@ -450,21 +479,7 @@ def test_lora_save_safetensors_load_torch(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: @@ -485,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self): model = self.model_class(**init_dict) model.to(torch_device) - lora_attn_procs = {} - for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = model.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(model.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = model.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - + lora_attn_procs = create_lora_layers(model, mock_weights=False) model.set_attn_processor(lora_attn_procs) # Saving as torch, properly reloads with directly filename with tempfile.TemporaryDirectory() as tmpdirname: @@ -531,7 +532,7 @@ def test_lora_on_off(self): with torch.no_grad(): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - model.set_attn_processor(AttnProcessor()) + model.set_default_attn_processor() with torch.no_grad(): new_sample = model(**inputs_dict).sample diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index a92b8edd5378..c552b503af05 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest import numpy as np import torch from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor from diffusers.utils import ( floats_tensor, logging, + skip_mps, torch_device, ) from diffusers.utils.import_utils import is_xformers_available @@ -34,10 +37,13 @@ torch.backends.cuda.matmul.allow_tf32 = False -def create_lora_layers(model): +def create_lora_layers(model, mock_weights: bool = True): lora_attn_procs = {} for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + has_cross_attention = name.endswith("attn2.processor") and not ( + name.startswith("transformer_in") or "temp_attentions" in name.split(".") + ) + cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -46,20 +52,25 @@ def create_lora_layers(model): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] + elif name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * model.config.attention_head_dim lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) - # add 1 to weights to mock trained weights - with torch.no_grad(): - lora_attn_procs[name].to_q_lora.up.weight += 1 - lora_attn_procs[name].to_k_lora.up.weight += 1 - lora_attn_procs[name].to_v_lora.up.weight += 1 - lora_attn_procs[name].to_out_lora.up.weight += 1 + if mock_weights: + # add 1 to weights to mock trained weights + with torch.no_grad(): + lora_attn_procs[name].to_q_lora.up.weight += 1 + lora_attn_procs[name].to_k_lora.up.weight += 1 + lora_attn_procs[name].to_v_lora.up.weight += 1 + lora_attn_procs[name].to_out_lora.up.weight += 1 return lora_attn_procs +@skip_mps class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet3DConditionModel @@ -86,19 +97,17 @@ def output_shape(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "block_out_channels": (32, 64, 64, 64), + "block_out_channels": (32, 64), "down_block_types": ( - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), "cross_attention_dim": 32, - "attention_head_dim": 4, + "attention_head_dim": 8, "out_channels": 4, "in_channels": 4, - "layers_per_block": 2, + "layers_per_block": 1, "sample_size": 32, } inputs_dict = self.dummy_input @@ -119,12 +128,11 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - # Overriding because `block_out_channels` needs to be different for this model. + # Overriding to set `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 - init_dict["block_out_channels"] = (32, 64, 64, 64) model = self.model_class(**init_dict) model.to(torch_device) @@ -191,23 +199,173 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None - # (`attn_processors`) needs to be implemented in this model for this test. - # def test_lora_processors(self): + def test_lora_processors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + + # make sure we can set a list of attention processors + model.set_attn_processor(lora_attn_procs) + model.to(torch_device) + + # test that attn processors can be set to itself + model.set_attn_processor(model.attn_processors) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 1e-4 + assert (sample3 - sample4).abs().max() < 1e-4 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 1e-4 + + def test_lora_save_load(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_load_safetensors(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname, safe_serialization=True) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname) + + with torch.no_grad(): + new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample - new_sample).abs().max() < 1e-4 + + # LoRA and no LoRA should NOT be the same + assert (sample - old_sample).abs().max() > 1e-4 + + def test_lora_save_safetensors_load_torch(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - # (`attn_processors`) needs to be implemented in this model for this test. - # def test_lora_save_load(self): + init_dict["attention_head_dim"] = 8 - # (`attn_processors`) needs to be implemented for this test in the model. - # def test_lora_save_load_safetensors(self): + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) - # (`attn_processors`) needs to be implemented for this test in the model. - # def test_lora_save_safetensors_load_torch(self): + lora_attn_procs = create_lora_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") + + def test_lora_save_torch_force_load_safetensors_error(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - # (`attn_processors`) needs to be implemented for this test. - # def test_lora_save_torch_force_load_safetensors_error(self): + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = create_lora_layers(model, mock_weights=False) + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + with self.assertRaises(IOError) as e: + new_model.load_attn_procs(tmpdirname, use_safetensors=True) + self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) + + def test_lora_on_off(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + old_sample = model(**inputs_dict).sample + + lora_attn_procs = create_lora_layers(model) + model.set_attn_processor(lora_attn_procs) + + with torch.no_grad(): + sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + + model.set_attn_processor(AttnProcessor()) + + with torch.no_grad(): + new_sample = model(**inputs_dict).sample - # (`attn_processors`) needs to be added for this test. - # def test_lora_on_off(self): + assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 1e-4 @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), diff --git a/tests/pipeline_params.py b/tests/pipeline_params.py index 2703801d4a7d..a0ac6c641c0b 100644 --- a/tests/pipeline_params.py +++ b/tests/pipeline_params.py @@ -102,3 +102,20 @@ UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"]) UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([]) + +TEXT_TO_AUDIO_PARAMS = frozenset( + [ + "prompt", + "audio_length_in_s", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + "cross_attention_kwargs", + ] +) + +TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) +TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"]) + +TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py index 939632943405..144107ec1c97 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -166,7 +166,7 @@ def test_stable_diffusion_img2img_default_case(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4115, 0.3870, 0.4089, 0.4807, 0.4668, 0.4144, 0.4151, 0.4721, 0.4569]) + expected_slice = np.array([0.4427, 0.3731, 0.4249, 0.4941, 0.4546, 0.4148, 0.4193, 0.4666, 0.4499]) assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-3 diff --git a/tests/pipelines/audio_diffusion/test_audio_diffusion.py b/tests/pipelines/audio_diffusion/test_audio_diffusion.py index ba389d9c936d..0eb6252410f5 100644 --- a/tests/pipelines/audio_diffusion/test_audio_diffusion.py +++ b/tests/pipelines/audio_diffusion/test_audio_diffusion.py @@ -115,8 +115,11 @@ def test_audio_diffusion(self): output = pipe(generator=generator, steps=4, return_dict=False) image_from_tuple = output[0][0] - assert audio.shape == (1, (self.dummy_unet.sample_size[1] - 1) * mel.hop_length) - assert image.height == self.dummy_unet.sample_size[0] and image.width == self.dummy_unet.sample_size[1] + assert audio.shape == (1, (self.dummy_unet.config.sample_size[1] - 1) * mel.hop_length) + assert ( + image.height == self.dummy_unet.config.sample_size[0] + and image.width == self.dummy_unet.config.sample_size[1] + ) image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10] image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10] expected_slice = np.array([69, 255, 255, 255, 0, 0, 77, 181, 12, 127]) @@ -133,14 +136,14 @@ def test_audio_diffusion(self): pipe.set_progress_bar_config(disable=None) np.random.seed(0) - raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].sample_size[1] - 1) * mel.hop_length,)) + raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].config.sample_size[1] - 1) * mel.hop_length,)) generator = torch.Generator(device=device).manual_seed(42) output = pipe(raw_audio=raw_audio, generator=generator, start_step=5, steps=10) image = output.images[0] assert ( - image.height == self.dummy_vqvae_and_unet[0].sample_size[0] - and image.width == self.dummy_vqvae_and_unet[0].sample_size[1] + image.height == self.dummy_vqvae_and_unet[0].config.sample_size[0] + and image.width == self.dummy_vqvae_and_unet[0].config.sample_size[1] ) image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10] expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121]) @@ -183,8 +186,8 @@ def test_audio_diffusion(self): audio = output.audios[0] image = output.images[0] - assert audio.shape == (1, (pipe.unet.sample_size[1] - 1) * pipe.mel.hop_length) - assert image.height == pipe.unet.sample_size[0] and image.width == pipe.unet.sample_size[1] + assert audio.shape == (1, (pipe.unet.config.sample_size[1] - 1) * pipe.mel.hop_length) + assert image.height == pipe.unet.config.sample_size[0] and image.width == pipe.unet.config.sample_size[1] image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10] expected_slice = np.array([151, 167, 154, 144, 122, 134, 121, 105, 70, 26]) diff --git a/tests/pipelines/audioldm/__init__.py b/tests/pipelines/audioldm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py new file mode 100644 index 000000000000..10de5440eb00 --- /dev/null +++ b/tests/pipelines/audioldm/test_audioldm.py @@ -0,0 +1,416 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +import unittest + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ( + ClapTextConfig, + ClapTextModelWithProjection, + RobertaTokenizer, + SpeechT5HifiGan, + SpeechT5HifiGanConfig, +) + +from diffusers import ( + AudioLDMPipeline, + AutoencoderKL, + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.utils import slow, torch_device + +from ...pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AudioLDMPipeline + params = TEXT_TO_AUDIO_PARAMS + batch_params = TEXT_TO_AUDIO_BATCH_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "num_waveforms_per_prompt", + "generator", + "latents", + "output_type", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=(32, 64), + class_embed_type="simple_projection", + projection_class_embeddings_input_dim=32, + class_embeddings_concat=True, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=1, + out_channels=1, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = ClapTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = ClapTextModelWithProjection(text_encoder_config) + tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77) + + vocoder_config = SpeechT5HifiGanConfig( + model_in_dim=8, + sampling_rate=16000, + upsample_initial_channel=16, + upsample_rates=[2, 2], + upsample_kernel_sizes=[4, 4], + resblock_kernel_sizes=[3, 7], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]], + normalize_before=False, + ) + + vocoder = SpeechT5HifiGan(vocoder_config) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "vocoder": vocoder, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A hammer hitting a wooden surface", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + } + return inputs + + def test_audioldm_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = audioldm_pipe(**inputs) + audio = output.audios[0] + + assert audio.ndim == 1 + assert len(audio) == 256 + + audio_slice = audio[:10] + expected_slice = np.array( + [-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033] + ) + + assert np.abs(audio_slice - expected_slice).max() < 1e-2 + + def test_audioldm_prompt_embeds(self): + components = self.get_dummy_components() + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = audioldm_pipe(**inputs) + audio_1 = output.audios[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = audioldm_pipe.tokenizer( + prompt, + padding="max_length", + max_length=audioldm_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputs = text_inputs["input_ids"].to(torch_device) + + prompt_embeds = audioldm_pipe.text_encoder( + text_inputs, + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + inputs["prompt_embeds"] = prompt_embeds + + # forward + output = audioldm_pipe(**inputs) + audio_2 = output.audios[0] + + assert np.abs(audio_1 - audio_2).max() < 1e-2 + + def test_audioldm_negative_prompt_embeds(self): + components = self.get_dummy_components() + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = audioldm_pipe(**inputs) + audio_1 = output.audios[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + embeds = [] + for p in [prompt, negative_prompt]: + text_inputs = audioldm_pipe.tokenizer( + p, + padding="max_length", + max_length=audioldm_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputs = text_inputs["input_ids"].to(torch_device) + + text_embeds = audioldm_pipe.text_encoder( + text_inputs, + ) + text_embeds = text_embeds.text_embeds + # additional L_2 normalization over each hidden-state + text_embeds = F.normalize(text_embeds, dim=-1) + + embeds.append(text_embeds) + + inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds + + # forward + output = audioldm_pipe(**inputs) + audio_2 = output.audios[0] + + assert np.abs(audio_1 - audio_2).max() < 1e-2 + + def test_audioldm_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = PNDMScheduler(skip_prk_steps=True) + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + negative_prompt = "egg cracking" + output = audioldm_pipe(**inputs, negative_prompt=negative_prompt) + audio = output.audios[0] + + assert audio.ndim == 1 + assert len(audio) == 256 + + audio_slice = audio[:10] + expected_slice = np.array( + [-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032] + ) + + assert np.abs(audio_slice - expected_slice).max() < 1e-2 + + def test_audioldm_num_waveforms_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = PNDMScheduler(skip_prk_steps=True) + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(device) + audioldm_pipe.set_progress_bar_config(disable=None) + + prompt = "A hammer hitting a wooden surface" + + # test num_waveforms_per_prompt=1 (default) + audios = audioldm_pipe(prompt, num_inference_steps=2).audios + + assert audios.shape == (1, 256) + + # test num_waveforms_per_prompt=1 (default) for batch of prompts + batch_size = 2 + audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios + + assert audios.shape == (batch_size, 256) + + # test num_waveforms_per_prompt for single prompt + num_waveforms_per_prompt = 2 + audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios + + assert audios.shape == (num_waveforms_per_prompt, 256) + + # test num_waveforms_per_prompt for batch of prompts + batch_size = 2 + audios = audioldm_pipe( + [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt + ).audios + + assert audios.shape == (batch_size * num_waveforms_per_prompt, 256) + + def test_audioldm_audio_length_in_s(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate + + inputs = self.get_dummy_inputs(device) + output = audioldm_pipe(audio_length_in_s=0.016, **inputs) + audio = output.audios[0] + + assert audio.ndim == 1 + assert len(audio) / vocoder_sampling_rate == 0.016 + + output = audioldm_pipe(audio_length_in_s=0.032, **inputs) + audio = output.audios[0] + + assert audio.ndim == 1 + assert len(audio) / vocoder_sampling_rate == 0.032 + + def test_audioldm_vocoder_model_in_dim(self): + components = self.get_dummy_components() + audioldm_pipe = AudioLDMPipeline(**components) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + prompt = ["hey"] + + output = audioldm_pipe(prompt, num_inference_steps=1) + audio_shape = output.audios.shape + assert audio_shape == (1, 256) + + config = audioldm_pipe.vocoder.config + config.model_in_dim *= 2 + audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device) + output = audioldm_pipe(prompt, num_inference_steps=1) + audio_shape = output.audios.shape + # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram + assert audio_shape == (1, 256) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(test_mean_pixel_difference=False) + + +@slow +# @require_torch_gpu +class AudioLDMPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): + generator = torch.Generator(device=generator_device).manual_seed(seed) + latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16)) + latents = torch.from_numpy(latents).to(device=device, dtype=dtype) + inputs = { + "prompt": "A hammer hitting a wooden surface", + "latents": latents, + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": 2.5, + } + return inputs + + def test_audioldm(self): + audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm") + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + inputs["num_inference_steps"] = 25 + audio = audioldm_pipe(**inputs).audios[0] + + assert audio.ndim == 1 + assert len(audio) == 81920 + + audio_slice = audio[77230:77240] + expected_slice = np.array( + [-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315] + ) + max_diff = np.abs(expected_slice - audio_slice).max() + assert max_diff < 1e-2 + + def test_audioldm_lms(self): + audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm") + audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config) + audioldm_pipe = audioldm_pipe.to(torch_device) + audioldm_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + audio = audioldm_pipe(**inputs).audios[0] + + assert audio.ndim == 1 + assert len(audio) == 81920 + + audio_slice = audio[27780:27790] + expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212]) + max_diff = np.abs(expected_slice - audio_slice).max() + assert max_diff < 1e-2 diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index 8e5b3aba9ecb..947fd3cbf43d 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -20,7 +20,7 @@ import torch from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel -from diffusers.utils import load_numpy, slow +from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from ...pipeline_params import ( @@ -92,12 +92,19 @@ def test_inference(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 16, 16, 3)) - expected_slice = np.array([0.4380, 0.4141, 0.5159, 0.0000, 0.4282, 0.6680, 0.5485, 0.2545, 0.6719]) + expected_slice = np.array([0.2946, 0.6601, 0.4329, 0.3296, 0.4144, 0.5319, 0.7273, 0.5013, 0.4457]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(relax_max_difference=True) + self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) @require_torch_gpu @@ -123,7 +130,7 @@ def test_dit_256(self): expected_image = load_numpy( f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy" ) - assert np.abs((expected_image - image).max()) < 1e-3 + assert np.abs((expected_image - image).max()) < 1e-2 def test_dit_512(self): pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512") diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index 3f2dbe5cec2a..2ff7feda6317 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -125,7 +125,7 @@ def test_inference_text2img(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 16, 16, 3) - expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332]) + expected_slice = np.array([0.6101, 0.6156, 0.5622, 0.4895, 0.6661, 0.3804, 0.5748, 0.6136, 0.5014]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index 81d1989200ac..14b045d6c480 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -129,7 +129,7 @@ def test_paint_by_example_inpaint(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4701, 0.5555, 0.3994, 0.5107, 0.5691, 0.4517, 0.5125, 0.4769, 0.4539]) + expected_slice = np.array([0.4686, 0.5687, 0.4007, 0.5218, 0.5741, 0.4482, 0.4940, 0.4629, 0.4503]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index b312c8184390..ba42b1fe9c5f 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -154,7 +154,7 @@ def test_semantic_diffusion_ddim(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) + expected_slice = np.array([0.5753, 0.6114, 0.5001, 0.5034, 0.5470, 0.4729, 0.4971, 0.4867, 0.4867]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -200,7 +200,7 @@ def test_semantic_diffusion_pndm(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) + expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/spectrogram_diffusion/__init__.py b/tests/pipelines/spectrogram_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py b/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py new file mode 100644 index 000000000000..594d7c598f75 --- /dev/null +++ b/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import DDPMScheduler, MidiProcessor, SpectrogramDiffusionPipeline +from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder +from diffusers.utils import require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import require_note_seq, require_onnxruntime + +from ...pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +MIDI_FILE = "./tests/fixtures/elise_format0.mid" + + +class SpectrogramDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SpectrogramDiffusionPipeline + required_optional_params = PipelineTesterMixin.required_optional_params - { + "callback", + "latents", + "callback_steps", + "output_type", + "num_images_per_prompt", + } + test_attention_slicing = False + test_cpu_offload = False + batch_params = TOKENS_TO_AUDIO_GENERATION_PARAMS + params = TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + notes_encoder = SpectrogramNotesEncoder( + max_length=2048, + vocab_size=1536, + d_model=768, + dropout_rate=0.1, + num_layers=1, + num_heads=1, + d_kv=4, + d_ff=2048, + feed_forward_proj="gated-gelu", + ) + + continuous_encoder = SpectrogramContEncoder( + input_dims=128, + targets_context_length=256, + d_model=768, + dropout_rate=0.1, + num_layers=1, + num_heads=1, + d_kv=4, + d_ff=2048, + feed_forward_proj="gated-gelu", + ) + + decoder = T5FilmDecoder( + input_dims=128, + targets_length=256, + max_decoder_noise_time=20000.0, + d_model=768, + num_layers=1, + num_heads=1, + d_kv=4, + d_ff=2048, + dropout_rate=0.1, + ) + + scheduler = DDPMScheduler() + + components = { + "notes_encoder": notes_encoder.eval(), + "continuous_encoder": continuous_encoder.eval(), + "decoder": decoder.eval(), + "scheduler": scheduler, + "melgan": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "input_tokens": [ + [1134, 90, 1135, 1133, 1080, 112, 1132, 1080, 1133, 1079, 133, 1132, 1079, 1133, 1] + [0] * 2033 + ], + "generator": generator, + "num_inference_steps": 4, + "output_type": "mel", + } + return inputs + + def test_spectrogram_diffusion(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = SpectrogramDiffusionPipeline(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + mel = output.audios + + mel_slice = mel[0, -3:, -3:] + + assert mel_slice.shape == (3, 3) + expected_slice = np.array( + [-11.512925, -4.788215, -0.46172905, -2.051715, -10.539147, -10.970963, -9.091634, 4.0, 4.0] + ) + assert np.abs(mel_slice.flatten() - expected_slice).max() < 1e-2 + + @skip_mps + def test_save_load_local(self): + return super().test_save_load_local() + + @skip_mps + def test_dict_tuple_outputs_equivalent(self): + return super().test_dict_tuple_outputs_equivalent() + + @skip_mps + def test_save_load_optional_components(self): + return super().test_save_load_optional_components() + + @skip_mps + def test_attention_slicing_forward_pass(self): + return super().test_attention_slicing_forward_pass() + + def test_inference_batch_single_identical(self): + pass + + def test_inference_batch_consistent(self): + pass + + @skip_mps + def test_progress_bar(self): + return super().test_progress_bar() + + +@slow +@require_torch_gpu +@require_onnxruntime +@require_note_seq +class PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_callback(self): + # TODO - test that pipeline can decode tokens in a callback + # so that music can be played live + device = torch_device + + pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + melgan = pipe.melgan + pipe.melgan = None + + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + def callback(step, mel_output): + # decode mel to audio + audio = melgan(input_features=mel_output.astype(np.float32))[0] + assert len(audio[0]) == 81920 * (step + 1) + # simulate that audio is played + return audio + + processor = MidiProcessor() + input_tokens = processor(MIDI_FILE) + + input_tokens = input_tokens[:3] + generator = torch.manual_seed(0) + pipe(input_tokens, num_inference_steps=5, generator=generator, callback=callback, output_type="mel") + + def test_spectrogram_fast(self): + device = torch_device + + pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + processor = MidiProcessor() + + input_tokens = processor(MIDI_FILE) + # just run two denoising loops + input_tokens = input_tokens[:2] + + generator = torch.manual_seed(0) + output = pipe(input_tokens, num_inference_steps=2, generator=generator) + + audio = output.audios[0] + + assert abs(np.abs(audio).sum() - 3612.841) < 1e-1 + + def test_spectrogram(self): + device = torch_device + + pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + processor = MidiProcessor() + + input_tokens = processor(MIDI_FILE) + + # just run 4 denoising loops + input_tokens = input_tokens[:4] + + generator = torch.manual_seed(0) + output = pipe(input_tokens, num_inference_steps=100, generator=generator) + + audio = output.audios[0] + assert abs(np.abs(audio).sum() - 9389.1111) < 5e-2 diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py index 74783faae421..3a5f9379ae50 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -133,6 +133,76 @@ def test_pipeline_dpm_multistep(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_prompt_embeds(self): + pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs() + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_inputs = text_inputs["input_ids"] + + prompt_embeds = pipe.text_encoder(input_ids=text_inputs.astype(np.int32))[0] + + inputs["prompt_embeds"] = prompt_embeds + + # forward + output = pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + def test_stable_diffusion_negative_prompt_embeds(self): + pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs() + prompt = 3 * [inputs.pop("prompt")] + + embeds = [] + for p in [prompt, negative_prompt]: + text_inputs = pipe.tokenizer( + p, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_inputs = text_inputs["input_ids"] + + embeds.append(pipe.text_encoder(input_ids=text_inputs.astype(np.int32))[0]) + + inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds + + # forward + output = pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + @nightly @require_onnxruntime diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py index 06e75d035d04..e1aa2f6dc0a1 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py @@ -81,7 +81,7 @@ def test_pipeline_pndm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.61710, 0.53390, 0.49310, 0.55622, 0.50982, 0.58240, 0.50716, 0.38629, 0.46856]) + expected_slice = np.array([0.61737, 0.54642, 0.53183, 0.54465, 0.52742, 0.60525, 0.49969, 0.40655, 0.48154]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 33ef9368586e..79796afdf597 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -21,6 +21,7 @@ import numpy as np import torch +from huggingface_hub import hf_hub_download from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( @@ -35,7 +36,6 @@ UNet2DConditionModel, logging, ) -from diffusers.models.attention_processor import AttnProcessor from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu @@ -135,7 +135,7 @@ def test_stable_diffusion_ddim(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5643, 0.6017, 0.4799, 0.5267, 0.5584, 0.4641, 0.5159, 0.4963, 0.4791]) + expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -282,7 +282,7 @@ def test_stable_diffusion_pndm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5094, 0.5674, 0.4667, 0.5125, 0.5696, 0.4674, 0.5277, 0.4964, 0.4945]) + expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -322,19 +322,7 @@ def test_stable_diffusion_k_lms(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [ - 0.47082293033599854, - 0.5371589064598083, - 0.4562119245529175, - 0.5220914483070374, - 0.5733777284622192, - 0.4795039892196655, - 0.5465868711471558, - 0.5074326395988464, - 0.5042197108268738, - ] - ) + expected_slice = np.array([0.4873, 0.5443, 0.4845, 0.5004, 0.5549, 0.4850, 0.5191, 0.4941, 0.5065]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -353,19 +341,7 @@ def test_stable_diffusion_k_euler_ancestral(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [ - 0.4707113206386566, - 0.5372191071510315, - 0.4563021957874298, - 0.5220003724098206, - 0.5734264850616455, - 0.4794946610927582, - 0.5463782548904419, - 0.5074145197868347, - 0.504422664642334, - ] - ) + expected_slice = np.array([0.4872, 0.5444, 0.4846, 0.5003, 0.5549, 0.4850, 0.5189, 0.4941, 0.5067]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -384,19 +360,7 @@ def test_stable_diffusion_k_euler(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [ - 0.47082313895225525, - 0.5371587872505188, - 0.4562119245529175, - 0.5220913887023926, - 0.5733776688575745, - 0.47950395941734314, - 0.546586811542511, - 0.5074326992034912, - 0.5042197108268738, - ] - ) + expected_slice = np.array([0.4873, 0.5443, 0.4845, 0.5004, 0.5549, 0.4850, 0.5191, 0.4941, 0.5065]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -468,19 +432,7 @@ def test_stable_diffusion_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [ - 0.5108221173286438, - 0.5688379406929016, - 0.4685141146183014, - 0.5098261833190918, - 0.5657756328582764, - 0.4631010890007019, - 0.5226285457611084, - 0.49129390716552734, - 0.4899061322212219, - ] - ) + expected_slice = np.array([0.5114, 0.5706, 0.4772, 0.5028, 0.5637, 0.4732, 0.5169, 0.4881, 0.4977]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -843,7 +795,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, ) - pipe.unet.set_attn_processor(AttnProcessor()) + pipe.unet.set_default_attn_processor() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) outputs = pipe(**inputs) @@ -856,7 +808,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, ) - pipe.unet.set_attn_processor(AttnProcessor()) + pipe.unet.set_default_attn_processor() torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() @@ -887,6 +839,31 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): assert mem_bytes_slicing < mem_bytes_offloaded assert mem_bytes_slicing < 3 * 10**9 + def test_stable_diffusion_textual_inversion(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) + pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) + pipe.to("cuda") + + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 5e-2 + @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py new file mode 100644 index 000000000000..268c01320177 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline +from diffusers.utils import is_flax_available, load_image, slow +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from flax.jax_utils import replicate + from flax.training.common_utils import shard + + +@slow +@require_flax +class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def test_canny(self): + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16 + ) + pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 + ) + params["controlnet"] = controlnet_params + + prompts = "bird" + num_samples = jax.device_count() + prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + + canny_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + rng = jax.random.PRNGKey(0) + rng = jax.random.split(rng, jax.device_count()) + + p_params = replicate(params) + prompt_ids = shard(prompt_ids) + processed_image = shard(processed_image) + + images = pipe( + prompt_ids=prompt_ids, + image=processed_image, + params=p_params, + prng_seed=rng, + num_inference_steps=50, + jit=True, + ).images + assert images.shape == (jax.device_count(), 1, 768, 512, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array( + [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078] + ) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 + + def test_pose(self): + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16 + ) + pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 + ) + params["controlnet"] = controlnet_params + + prompts = "Chef in the kitchen" + num_samples = jax.device_count() + prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + + pose_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png" + ) + processed_image = pipe.prepare_image_inputs([pose_image] * num_samples) + + rng = jax.random.PRNGKey(0) + rng = jax.random.split(rng, jax.device_count()) + + p_params = replicate(params) + prompt_ids = shard(prompt_ids) + processed_image = shard(processed_image) + + images = pipe( + prompt_ids=prompt_ids, + image=processed_image, + params=p_params, + prng_seed=rng, + num_inference_steps=50, + jit=True, + ).images + assert images.shape == (jax.device_count(), 1, 768, 512, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array( + [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]] + ) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 01c2e22e4816..2a07ab64a36d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -119,7 +119,7 @@ def test_stable_diffusion_img_variation_default_case(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958]) + expected_slice = np.array([0.5239, 0.5723, 0.4796, 0.5049, 0.5550, 0.4685, 0.5329, 0.4891, 0.4921]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -139,7 +139,7 @@ def test_stable_diffusion_img_variation_multiple_images(self): image_slice = image[-1, -3:, -3:, -1] assert image.shape == (2, 64, 64, 3) - expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263]) + expected_slice = np.array([0.6892, 0.5637, 0.5836, 0.5771, 0.6254, 0.6409, 0.5580, 0.5569, 0.5289]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index e27f83fc04fe..69b92f685f25 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -138,7 +138,7 @@ def test_stable_diffusion_img2img_default_case(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218]) + expected_slice = np.array([0.4555, 0.3216, 0.4049, 0.4620, 0.4618, 0.4126, 0.4122, 0.4629, 0.4579]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -157,7 +157,7 @@ def test_stable_diffusion_img2img_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365]) + expected_slice = np.array([0.4593, 0.3408, 0.4232, 0.4749, 0.4476, 0.4115, 0.4357, 0.4733, 0.4663]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -176,7 +176,7 @@ def test_stable_diffusion_img2img_multiple_init_images(self): image_slice = image[-1, -3:, -3:, -1] assert image.shape == (2, 32, 32, 3) - expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689]) + expected_slice = np.array([0.4241, 0.5576, 0.5711, 0.4792, 0.4311, 0.5952, 0.5827, 0.5138, 0.5109]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -196,7 +196,7 @@ def test_stable_diffusion_img2img_k_lms(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203]) + expected_slice = np.array([0.4398, 0.4949, 0.4337, 0.6580, 0.5555, 0.4338, 0.5769, 0.5955, 0.5175]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 25b0c6ea1432..78e697fbbac3 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -124,7 +124,7 @@ def test_stable_diffusion_pix2pix_default_case(self): image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813]) + expected_slice = np.array([0.7526, 0.3750, 0.4547, 0.6117, 0.5866, 0.5016, 0.4327, 0.5642, 0.4815]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -142,7 +142,7 @@ def test_stable_diffusion_pix2pix_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827]) + expected_slice = np.array([0.7511, 0.3642, 0.4553, 0.6236, 0.5797, 0.5013, 0.4343, 0.5611, 0.4831]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -165,7 +165,7 @@ def test_stable_diffusion_pix2pix_multiple_init_images(self): image_slice = image[-1, -3:, -3:, -1] assert image.shape == (2, 32, 32, 3) - expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607]) + expected_slice = np.array([0.5812, 0.5748, 0.5222, 0.5908, 0.5695, 0.7174, 0.6804, 0.5523, 0.5579]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -187,7 +187,7 @@ def test_stable_diffusion_pix2pix_euler(self): print(",".join([str(x) for x in slice])) assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846]) + expected_slice = np.array([0.7417, 0.3842, 0.4732, 0.5776, 0.5891, 0.5139, 0.4052, 0.5673, 0.4986]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py index 7869790c6218..546b1d21252c 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py @@ -75,3 +75,32 @@ def test_stable_diffusion_2(self): expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112]) assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1 + + def test_stable_diffusion_karras_sigmas(self): + sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + sd_pipe.set_scheduler("sample_dpmpp_2m") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=7.5, + num_inference_steps=15, + output_type="np", + use_karras_sigmas=True, + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py new file mode 100644 index 000000000000..1e11500c72b1 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + PNDMScheduler, + StableDiffusionModelEditingPipeline, + UNet2DConditionModel, +) +from diffusers.utils import slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu, skip_mps + +from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +@skip_mps +class StableDiffusionModelEditingPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionModelEditingPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler() + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + generator = torch.manual_seed(seed) + inputs = { + "prompt": "A field of roses", + "generator": generator, + # Setting height and width to None to prevent OOMs on CPU. + "height": None, + "width": None, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def test_stable_diffusion_model_editing_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionModelEditingPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.4755, 0.5132, 0.4976, 0.3904, 0.3554, 0.4765, 0.5139, 0.5158, 0.4889]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_model_editing_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionModelEditingPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + negative_prompt = "french fries" + output = sd_pipe(**inputs, negative_prompt=negative_prompt) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.4992, 0.5101, 0.5004, 0.3949, 0.3604, 0.4735, 0.5216, 0.5204, 0.4913]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_model_editing_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = EulerAncestralDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + sd_pipe = StableDiffusionModelEditingPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.4747, 0.5372, 0.4779, 0.4982, 0.5543, 0.4816, 0.5238, 0.4904, 0.5027]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_model_editing_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = PNDMScheduler() + sd_pipe = StableDiffusionModelEditingPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + # the pipeline does not expect pndm so test if it raises error. + with self.assertRaises(ValueError): + _ = sd_pipe(**inputs).images + + +@slow +@require_torch_gpu +class StableDiffusionModelEditingSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, seed=0): + generator = torch.manual_seed(seed) + inputs = { + "prompt": "A field of roses", + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": 7.5, + "output_type": "numpy", + } + return inputs + + def test_stable_diffusion_model_editing_default(self): + model_ckpt = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs() + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + + expected_slice = np.array( + [0.6749496, 0.6386453, 0.51443267, 0.66094905, 0.61921215, 0.5491332, 0.5744417, 0.58075106, 0.5174658] + ) + + assert np.abs(expected_slice - image_slice).max() < 1e-2 + + # make sure image changes after editing + pipe.edit_model("A pack of roses", "A pack of blue roses") + + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + + assert np.abs(expected_slice - image_slice).max() > 1e-1 + + def test_stable_diffusion_model_editing_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + model_ckpt = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") + pipe = StableDiffusionModelEditingPipeline.from_pretrained( + model_ckpt, scheduler=scheduler, safety_checker=None + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + inputs = self.get_inputs() + _ = pipe(**inputs) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 4.4 GB is allocated + assert mem_bytes < 4.4 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index af26e19cca73..de9e8a79fb34 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -119,7 +119,7 @@ def test_stable_diffusion_panorama_default_case(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5101, 0.5006, 0.4962, 0.3995, 0.3501, 0.4632, 0.5339, 0.525, 0.4878]) + expected_slice = np.array([0.4794, 0.5084, 0.4992, 0.3941, 0.3555, 0.4754, 0.5248, 0.5224, 0.4839]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -138,7 +138,7 @@ def test_stable_diffusion_panorama_negative_prompt(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5326, 0.5009, 0.5074, 0.4133, 0.371, 0.464, 0.5432, 0.5429, 0.4896]) + expected_slice = np.array([0.5029, 0.5075, 0.5002, 0.3965, 0.3584, 0.4746, 0.5271, 0.5273, 0.4877]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -158,9 +158,7 @@ def test_stable_diffusion_panorama_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [0.48235387, 0.5423796, 0.46016198, 0.5377287, 0.5803722, 0.4876525, 0.5515428, 0.5045897, 0.50709957] - ) + expected_slice = np.array([0.4934, 0.5455, 0.4847, 0.5022, 0.5572, 0.4833, 0.5207, 0.4952, 0.5051]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 46b93a0589ce..59c45d603b91 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -133,7 +133,7 @@ def test_stable_diffusion_pix2pix_zero_default_case(self): image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5184, 0.503, 0.4917, 0.4022, 0.3455, 0.464, 0.5324, 0.5323, 0.4894]) + expected_slice = np.array([0.4863, 0.5053, 0.5033, 0.4007, 0.3571, 0.4768, 0.5176, 0.5277, 0.4940]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -151,7 +151,7 @@ def test_stable_diffusion_pix2pix_zero_negative_prompt(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5464, 0.5072, 0.5012, 0.4124, 0.3624, 0.466, 0.5413, 0.5468, 0.4927]) + expected_slice = np.array([0.5177, 0.5097, 0.5047, 0.4076, 0.3667, 0.4767, 0.5238, 0.5307, 0.4958]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -170,7 +170,7 @@ def test_stable_diffusion_pix2pix_zero_euler(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5114, 0.5051, 0.5222, 0.5279, 0.5037, 0.5156, 0.4604, 0.4966, 0.504]) + expected_slice = np.array([0.5421, 0.5525, 0.6085, 0.5279, 0.4658, 0.5317, 0.4418, 0.4815, 0.5132]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -187,7 +187,7 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5185, 0.5027, 0.492, 0.401, 0.3445, 0.464, 0.5321, 0.5327, 0.4892]) + expected_slice = np.array([0.4861, 0.5053, 0.5038, 0.3994, 0.3562, 0.4768, 0.5172, 0.5280, 0.4938]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 481c265cbee4..7b607c8fdd36 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -32,7 +32,6 @@ UNet2DConditionModel, logging, ) -from diffusers.models.attention_processor import AttnProcessor from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu @@ -135,7 +134,7 @@ def test_stable_diffusion_ddim(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5649, 0.6022, 0.4804, 0.5270, 0.5585, 0.4643, 0.5159, 0.4963, 0.4793]) + expected_slice = np.array([0.5753, 0.6113, 0.5005, 0.5036, 0.5464, 0.4725, 0.4982, 0.4865, 0.4861]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -152,7 +151,7 @@ def test_stable_diffusion_pndm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946]) + expected_slice = np.array([0.5121, 0.5714, 0.4827, 0.5057, 0.5646, 0.4766, 0.5189, 0.4895, 0.4990]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -169,7 +168,7 @@ def test_stable_diffusion_k_lms(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + expected_slice = np.array([0.4865, 0.5439, 0.4840, 0.4995, 0.5543, 0.4846, 0.5199, 0.4942, 0.5061]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -186,7 +185,7 @@ def test_stable_diffusion_k_euler_ancestral(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046]) + expected_slice = np.array([0.4864, 0.5440, 0.4842, 0.4994, 0.5543, 0.4846, 0.5196, 0.4942, 0.5063]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -203,7 +202,7 @@ def test_stable_diffusion_k_euler(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + expected_slice = np.array([0.4865, 0.5439, 0.4840, 0.4995, 0.5543, 0.4846, 0.5199, 0.4942, 0.5061]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -410,7 +409,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): "stabilityai/stable-diffusion-2-base", torch_dtype=torch.float16, ) - pipe.unet.set_attn_processor(AttnProcessor()) + pipe.unet.set_default_attn_processor() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) outputs = pipe(**inputs) @@ -423,7 +422,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self): "stabilityai/stable-diffusion-2-base", torch_dtype=torch.float16, ) - pipe.unet.set_attn_processor(AttnProcessor()) + pipe.unet.set_default_attn_processor() torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 780abf304a46..90bb1461d351 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -132,9 +132,7 @@ def test_inference(self): image_slice = image[0, -3:, -3:, -1] self.assertEqual(image.shape, (1, 64, 64, 3)) - expected_slice = np.array( - [0.5644937, 0.60543084, 0.48239064, 0.5206757, 0.55623394, 0.46045133, 0.5100435, 0.48919064, 0.4759359] - ) + expected_slice = np.array([0.5743, 0.6081, 0.4975, 0.5021, 0.5441, 0.4699, 0.4988, 0.4841, 0.4851]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index c2ad239f6888..6b0205f3faeb 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -289,7 +289,7 @@ def test_stable_diffusion_depth2img_default_case(self): if torch_device == "mps": expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) else: - expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) + expected_slice = np.array([0.5435, 0.4992, 0.3783, 0.4411, 0.5842, 0.4654, 0.3786, 0.5077, 0.4655]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -308,9 +308,9 @@ def test_stable_diffusion_depth2img_negative_prompt(self): assert image.shape == (1, 32, 32, 3) if torch_device == "mps": - expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) - else: expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626]) + else: + expected_slice = np.array([0.6012, 0.4507, 0.3769, 0.4121, 0.5566, 0.4585, 0.3803, 0.5045, 0.4631]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -332,7 +332,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): if torch_device == "mps": expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) else: - expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674]) + expected_slice = np.array([0.6557, 0.6214, 0.6254, 0.5775, 0.4785, 0.5949, 0.5904, 0.4785, 0.4730]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -351,7 +351,7 @@ def test_stable_diffusion_depth2img_pil(self): if torch_device == "mps": expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) else: - expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716]) + expected_slice = np.array([0.5435, 0.4992, 0.3783, 0.4411, 0.5842, 0.4654, 0.3786, 0.5077, 0.4655]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -397,7 +397,7 @@ def test_stable_diffusion_depth2img_pipeline_default(self): image_slice = image[0, 253:256, 253:256, -1].flatten() assert image.shape == (1, 480, 640, 3) - expected_slice = np.array([0.9057, 0.9365, 0.9258, 0.8937, 0.8555, 0.8541, 0.8260, 0.7747, 0.7421]) + expected_slice = np.array([0.5435, 0.4992, 0.3783, 0.4411, 0.5842, 0.4654, 0.3786, 0.5077, 0.4655]) assert np.abs(expected_slice - image_slice).max() < 1e-4 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index b8e7b858130b..747809a4fb2e 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -154,7 +154,7 @@ def test_stable_diffusion_upscale(self): expected_height_width = low_res_image.size[0] * 4 assert image.shape == (1, expected_height_width, expected_height_width, 3) - expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606]) + expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index 8aab5845741c..083640a87ba9 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -144,7 +144,7 @@ def test_stable_diffusion_v_pred_ddim(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.6424, 0.6109, 0.494, 0.5088, 0.4984, 0.4525, 0.5059, 0.5068, 0.4474]) + expected_slice = np.array([0.6569, 0.6525, 0.5142, 0.4968, 0.4923, 0.4601, 0.4996, 0.5041, 0.4544]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -193,7 +193,7 @@ def test_stable_diffusion_v_pred_k_euler(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776]) + expected_slice = np.array([0.5644, 0.6514, 0.5190, 0.5663, 0.5287, 0.4953, 0.5430, 0.5243, 0.4778]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index 2f393a66d166..c614fa48055e 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -154,7 +154,7 @@ def test_safe_diffusion_ddim(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) + expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -200,7 +200,7 @@ def test_stable_diffusion_pndm(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) + expected_slice = np.array([0.5125, 0.5716, 0.4828, 0.5060, 0.5650, 0.4768, 0.5185, 0.4895, 0.4993]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 1db8c3801007..907853394040 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -2,9 +2,10 @@ import random import unittest +import numpy as np import torch from transformers import ( - CLIPFeatureExtractor, + CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, @@ -16,7 +17,15 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_image, load_numpy, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_image, + load_numpy, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ...test_pipelines_common import ( @@ -36,8 +45,9 @@ def get_dummy_components(self): # image encoding components - feature_extractor = CLIPFeatureExtractor(crop_size=32, size=32) + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) + torch.manual_seed(0) image_encoder = CLIPVisionModelWithProjection( CLIPVisionConfig( hidden_size=embedder_hidden_size, @@ -110,16 +120,16 @@ def get_dummy_components(self): components = { # image encoding components "feature_extractor": feature_extractor, - "image_encoder": image_encoder, + "image_encoder": image_encoder.eval(), # image noising components - "image_normalizer": image_normalizer, + "image_normalizer": image_normalizer.eval(), "image_noising_scheduler": image_noising_scheduler, # regular denoising components "tokenizer": tokenizer, - "text_encoder": text_encoder, - "unet": unet, + "text_encoder": text_encoder.eval(), + "unet": unet.eval(), "scheduler": scheduler, - "vae": vae, + "vae": vae.eval(), } return components @@ -146,6 +156,24 @@ def get_dummy_inputs(self, device, seed=0, pil_image=True): "output_type": "np", } + @skip_mps + def test_image_embeds_none(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableUnCLIPImg2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs.update({"image_embeds": None}) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.3872, 0.7224, 0.5601, 0.4741, 0.6872, 0.5814, 0.4636, 0.3867, 0.5078]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass # because GPU undeterminism requires a looser check. def test_attention_slicing_forward_pass(self): @@ -197,7 +225,7 @@ def test_stable_unclip_l_img2img(self): pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) - output = pipe("anime turle", image=input_image, generator=generator, output_type="np") + output = pipe(input_image, "anime turle", generator=generator, output_type="np") image = output.images[0] @@ -225,7 +253,7 @@ def test_stable_unclip_h_img2img(self): pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) - output = pipe("anime turle", image=input_image, generator=generator, output_type="np") + output = pipe(input_image, "anime turle", generator=generator, output_type="np") image = output.images[0] @@ -251,8 +279,8 @@ def test_stable_unclip_img2img_pipeline_with_sequential_cpu_offloading(self): pipe.enable_sequential_cpu_offload() _ = pipe( + input_image, "anime turtle", - image=input_image, num_inference_steps=2, output_type="np", ) diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index eb43a360653a..438e685a443c 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -35,6 +35,7 @@ torch.backends.cuda.matmul.allow_tf32 = False +@skip_mps class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = TextToVideoSDPipeline params = TEXT_TO_IMAGE_PARAMS @@ -134,7 +135,7 @@ def test_text_to_video_default_case(self): image_slice = frames[0][-3:, -3:, -1] assert frames[0].shape == (64, 64, 3) - expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114]) + expected_slice = np.array([158.0, 160.0, 153.0, 125.0, 100.0, 121.0, 111.0, 93.0, 113.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -155,12 +156,12 @@ def test_inference_batch_single_identical(self): def test_num_images_per_prompt(self): pass - @skip_mps def test_progress_bar(self): return super().test_progress_bar() @slow +@skip_mps class TextToVideoSDPipelineSlowTests(unittest.TestCase): def test_full_model(self): expected_video = load_numpy( diff --git a/tests/pipelines/text_to_video/test_text_to_video_zero.py b/tests/pipelines/text_to_video/test_text_to_video_zero.py new file mode 100644 index 000000000000..45bb93fbd9c6 --- /dev/null +++ b/tests/pipelines/text_to_video/test_text_to_video_zero.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import DDIMScheduler, TextToVideoZeroPipeline +from diffusers.utils import load_pt, require_torch_gpu, slow + +from ...test_pipelines_common import assert_mean_pixel_difference + + +@slow +@require_torch_gpu +class TextToVideoZeroPipelineSlowTests(unittest.TestCase): + def test_full_model(self): + model_id = "runwayml/stable-diffusion-v1-5" + pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + generator = torch.Generator(device="cuda").manual_seed(0) + + prompt = "A bear is playing a guitar on Times Square" + result = pipe(prompt=prompt, generator=generator).images + + expected_result = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt" + ) + + assert_mean_pixel_difference(result, expected_result) diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index ff32ac5f9aaf..3cacb0bcad0b 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa "decoder_num_inference_steps", "super_res_num_inference_steps", ] + test_xformers_attention = False @property def text_embedder_hidden_size(self): @@ -420,7 +421,12 @@ class DummyScheduler: def test_attention_slicing_forward_pass(self): test_max_difference = torch_device == "cpu" - self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference) + # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor + expected_max_diff = 1e-2 + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, expected_max_diff=expected_max_diff + ) # Overriding PipelineTesterMixin::test_inference_batch_single_identical # because UnCLIP undeterminism requires a looser check. diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index 6769240db905..d97a7b2f6564 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -143,7 +143,7 @@ def test_vq_diffusion(self): assert image.shape == (1, 24, 24, 3) - expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880]) + expected_slice = np.array([0.6551, 0.6168, 0.5008, 0.5676, 0.5659, 0.4295, 0.6073, 0.5599, 0.4992]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -187,7 +187,7 @@ def test_vq_diffusion_classifier_free_sampling(self): assert image.shape == (1, 24, 24, 3) - expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912]) + expected_slice = np.array([0.6693, 0.6075, 0.4959, 0.5701, 0.5583, 0.4333, 0.6171, 0.5684, 0.4988]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 295bbe882746..c1593bae3908 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -201,7 +201,7 @@ def test_full_loop_no_noise_thres(self): sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.6405) < 1e-3 + assert abs(result_mean.item() - 1.1364) < 1e-3 def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") @@ -209,6 +209,12 @@ def test_full_loop_with_v_prediction(self): assert abs(result_mean.item() - 0.2251) < 1e-3 + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2096) < 1e-3 + def test_switch(self): # make sure that iterating over schedulers with same config names gives same results # for defaults @@ -243,3 +249,11 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 4d521b0075e1..aa46ef31885a 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -117,3 +117,30 @@ def test_full_loop_device(self): assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_mean.item() - 0.0131) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 124.52299499511719) < 1e-2 + assert abs(result_mean.item() - 0.16213932633399963) < 1e-3 diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 6154c8e2d625..62cffc67388c 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -229,3 +229,11 @@ def test_fp16_support(self): sample = scheduler.step(residual, t, sample).prev_sample assert sample.dtype == torch.float16 + + def test_unique_timesteps(self, **config): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index d0e2102b539e..db0d6c78d902 100644 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -411,9 +411,8 @@ def test_spatial_transformer_cross_attention_dim(self): assert attention_scores.shape == (1, 64, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] - expected_slice = torch.tensor( - [-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device + [0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) @@ -445,14 +444,14 @@ def test_spatial_transformer_timestep(self): output_slice_1 = attention_scores_1[0, -1, -3:, -3:] output_slice_2 = attention_scores_2[0, -1, -3:, -3:] - expected_slice_1 = torch.tensor( - [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device + expected_slice = torch.tensor( + [-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device ) expected_slice_2 = torch.tensor( - [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device + [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device ) - assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3) + assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3) assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3) def test_spatial_transformer_dropout(self): diff --git a/tests/test_lora_layers.py b/tests/test_lora_layers.py new file mode 100644 index 000000000000..9bcdc5d93301 --- /dev/null +++ b/tests/test_lora_layers.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +import torch +import torch.nn as nn +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device + + +def create_unet_lora_layers(unet: nn.Module): + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_layers = AttnProcsLayers(lora_attn_procs) + return lora_attn_procs, unet_lora_layers + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + return text_encoder_lora_layers + + +class LoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "text_encoder_lora_layers": text_encoder_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_safetensors(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_legacy(self): + pipeline_components, lora_components = self.get_dummy_components() + unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + unet = sd_pipe.unet + unet.set_attn_processor(unet_lora_attn_procs) + unet.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e880950a7914..40aba3b24967 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,9 +25,9 @@ from requests.exceptions import HTTPError from diffusers.models import UNet2DConditionModel -from diffusers.models.attention_processor import AttnProcessor from diffusers.training_utils import EMAModel from diffusers.utils import torch_device +from diffusers.utils.testing_utils import require_torch_gpu class ModelUtilsTest(unittest.TestCase): @@ -100,22 +100,46 @@ def test_one_request_upon_cached(self): diffusers.utils.import_utils._safetensors_available = True + def test_weight_overwrite(self): + with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: + UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + ) + + # make sure that error message states what keys are missing + assert "Cannot load" in str(error_context.exception) + + with tempfile.TemporaryDirectory() as tmpdirname: + model = UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + low_cpu_mem_usage=False, + ignore_mismatched_sizes=True, + ) + + assert model.config.in_channels == 9 + class ModelTesterMixin: def test_from_save_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - if hasattr(model, "set_attn_processor"): - model.set_attn_processor(AttnProcessor()) + if hasattr(model, "set_default_attn_processor"): + model.set_default_attn_processor() model.to(torch_device) model.eval() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname) - if hasattr(new_model, "set_attn_processor"): - new_model.set_attn_processor(AttnProcessor()) + if hasattr(new_model, "set_default_attn_processor"): + new_model.set_default_attn_processor() new_model.to(torch_device) with torch.no_grad(): @@ -135,16 +159,16 @@ def test_from_save_pretrained_variant(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - if hasattr(model, "set_attn_processor"): - model.set_attn_processor(AttnProcessor()) + if hasattr(model, "set_default_attn_processor"): + model.set_default_attn_processor() model.to(torch_device) model.eval() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, variant="fp16") new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") - if hasattr(new_model, "set_attn_processor"): - new_model.set_attn_processor(AttnProcessor()) + if hasattr(new_model, "set_default_attn_processor"): + new_model.set_default_attn_processor() # non-variant cannot be loaded with self.assertRaises(OSError) as error_context: @@ -168,6 +192,21 @@ def test_from_save_pretrained_variant(self): max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + @require_torch_gpu + def test_from_save_pretrained_dynamo(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model = torch.compile(model) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + assert new_model.__class__ == self.model_class + def test_from_save_pretrained_dtype(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 9f0c9b1a4e19..a5d70b01d453 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -31,7 +31,7 @@ from parameterized import parameterized from PIL import Image from requests.exceptions import HTTPError -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, @@ -54,7 +54,16 @@ logging, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device +from diffusers.utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + floats_tensor, + is_flax_available, + nightly, + require_torch_2, + slow, + torch_device, +) from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu @@ -69,9 +78,7 @@ def test_one_request_upon_cached(self): with tempfile.TemporaryDirectory() as tmpdirname: with requests_mock.mock(real_http=True) as m: - DiffusionPipeline.download( - "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname - ) + DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe", cache_dir=tmpdirname) download_requests = [r.method for r in m.request_history] assert download_requests.count("HEAD") == 15, "15 calls to files" @@ -92,6 +99,55 @@ def test_one_request_upon_cached(self): len(cache_requests) == 2 ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + def test_less_downloads_passed_object(self): + with tempfile.TemporaryDirectory() as tmpdirname: + cached_folder = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + # make sure safety checker is not downloaded + assert "safety_checker" not in os.listdir(cached_folder) + + # make sure rest is downloaded + assert "unet" in os.listdir(cached_folder) + assert "tokenizer" in os.listdir(cached_folder) + assert "vae" in os.listdir(cached_folder) + assert "model_index.json" in os.listdir(cached_folder) + assert "scheduler" in os.listdir(cached_folder) + assert "feature_extractor" in os.listdir(cached_folder) + + def test_less_downloads_passed_object_calls(self): + # TODO: For some reason this test fails on MPS where no HEAD call is made. + if torch_device == "mps": + return + + with tempfile.TemporaryDirectory() as tmpdirname: + with requests_mock.mock(real_http=True) as m: + DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + download_requests = [r.method for r in m.request_history] + # 15 - 2 because no call to config or model file for `safety_checker` + assert download_requests.count("HEAD") == 13, "13 calls to files" + # 17 - 2 because no call to config or model file for `safety_checker` + assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json" + assert ( + len(download_requests) == 28 + ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json" + + with requests_mock.mock(real_http=True) as m: + DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + cache_requests = [r.method for r in m.request_history] + assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD" + assert cache_requests.count("GET") == 1, "model info is only GET" + assert ( + len(cache_requests) == 2 + ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + def test_download_only_pytorch(self): with tempfile.TemporaryDirectory() as tmpdirname: # pipeline has Flax weights @@ -156,6 +212,54 @@ def test_download_safetensors(self): # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack assert not any(f.endswith(".bin") for f in files) + def test_download_safetensors_index(self): + for variant in ["fp16", None]: + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", + cache_dir=tmpdirname, + use_safetensors=True, + variant=variant, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a safetensors file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder + if variant is None: + assert not any("fp16" in f for f in files) + else: + model_files = [f for f in files if "safetensors" in f] + assert all("fp16" in f for f in model_files) + + assert len([f for f in files if ".safetensors" in f]) == 8 + assert not any(".bin" in f for f in files) + + def test_download_bin_index(self): + for variant in ["fp16", None]: + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", + cache_dir=tmpdirname, + use_safetensors=False, + variant=variant, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a safetensors file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder + if variant is None: + assert not any("fp16" in f for f in files) + else: + model_files = [f for f in files if "bin" in f] + assert all("fp16" in f for f in model_files) + + assert len([f for f in files if ".bin" in f]) == 8 + assert not any(".safetensors" in f for f in files) + def test_download_no_safety_checker(self): prompt = "hello" pipe = StableDiffusionPipeline.from_pretrained( @@ -353,6 +457,124 @@ def test_download_broken_variant(self): diffusers.utils.import_utils._safetensors_available = True + def test_local_save_load_index(self): + prompt = "hello" + for variant in [None, "fp16"]: + for use_safe in [True, False]: + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe-indexes", + variant=variant, + use_safetensors=use_safe, + safety_checker=None, + ) + pipe = pipe.to(torch_device) + generator = torch.manual_seed(0) + out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe_2 = StableDiffusionPipeline.from_pretrained( + tmpdirname, safe_serialization=use_safe, variant=variant + ) + pipe_2 = pipe_2.to(torch_device) + + generator = torch.manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images + + assert np.max(np.abs(out - out_2)) < 1e-3 + + def test_text_inversion_download(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe = pipe.to(torch_device) + + num_tokens = len(pipe.tokenizer) + + # single token load local + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<*>": torch.ones((32,))} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + pipe.load_textual_inversion(tmpdirname) + + token = pipe.tokenizer.convert_tokens_to_ids("<*>") + assert token == num_tokens, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 + assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>" + + prompt = "hey <*>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # single token load local with weight name + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<**>": 2 * torch.ones((1, 32))} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin") + + token = pipe.tokenizer.convert_tokens_to_ids("<**>") + assert token == num_tokens + 1, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 + assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>" + + prompt = "hey <**>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi token load + with tempfile.TemporaryDirectory() as tmpdirname: + ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])} + torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin")) + + pipe.load_textual_inversion(tmpdirname) + + token = pipe.tokenizer.convert_tokens_to_ids("<***>") + token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2") + + assert token == num_tokens + 2, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2" + + prompt = "hey <***>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi token load a1111 + with tempfile.TemporaryDirectory() as tmpdirname: + ten = { + "string_to_param": { + "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))]) + }, + "name": "<****>", + } + torch.save(ten, os.path.join(tmpdirname, "a1111.bin")) + + pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin") + + token = pipe.tokenizer.convert_tokens_to_ids("<****>") + token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2") + + assert token == num_tokens + 5, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2" + + prompt = "hey <****>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self): @@ -433,7 +655,7 @@ def test_local_custom_pipeline_file(self): def test_download_from_git(self): clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" - feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) + feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id) clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16) pipeline = DiffusionPipeline.from_pretrained( @@ -453,6 +675,25 @@ def test_download_from_git(self): image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0] assert image.shape == (512, 512, 3) + def test_save_pipeline_change_config(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = DiffusionPipeline.from_pretrained(tmpdirname) + + assert pipe.scheduler.__class__.__name__ == "PNDMScheduler" + + # let's make sure that changing the scheduler is correctly reflected + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.save_pretrained(tmpdirname) + pipe = DiffusionPipeline.from_pretrained(tmpdirname) + + assert pipe.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler" + class PipelineFastTests(unittest.TestCase): def tearDown(self): @@ -688,6 +929,46 @@ def test_set_scheduler(self): sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config) assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) + def test_set_component_to_none(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + pipeline = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + generator = torch.Generator(device="cpu").manual_seed(0) + + prompt = "This is a flower" + + out_image = pipeline( + prompt=prompt, + generator=generator, + num_inference_steps=1, + output_type="np", + ).images + + pipeline.feature_extractor = None + generator = torch.Generator(device="cpu").manual_seed(0) + out_image_2 = pipeline( + prompt=prompt, + generator=generator, + num_inference_steps=1, + output_type="np", + ).images + + assert out_image.shape == (1, 64, 64, 3) + assert np.abs(out_image - out_image_2).max() < 1e-3 + def test_set_scheduler_consistency(self): unet = self.dummy_cond_unet() pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") @@ -966,9 +1247,41 @@ def test_from_save_pretrained(self): down_block_types=("DownBlock2D", "AttnDownBlock2D"), up_block_types=("AttnUpBlock2D", "UpBlock2D"), ) - schedular = DDPMScheduler(num_train_timesteps=10) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + new_ddpm.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + generator = torch.Generator(device=torch_device).manual_seed(0) + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + + @require_torch_2 + def test_from_save_pretrained_dynamo(self): + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + model = torch.compile(model) + scheduler = DDPMScheduler(num_train_timesteps=10) - ddpm = DDPMPipeline(model, schedular) + ddpm = DDPMPipeline(model, scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index a461930f3a83..294dad5ff0f1 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -28,7 +28,6 @@ import jax.numpy as jnp from flax.jax_utils import replicate from flax.training.common_utils import shard - from jax import pmap from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline @@ -70,22 +69,19 @@ def test_dummy_all_tpus(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 64, 64, 3) if jax.device_count() == 8: - assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3 - assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1 + assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3 + assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1 images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) - assert len(images_pil) == num_samples def test_stable_diffusion_v1_4(self): @@ -105,14 +101,12 @@ def test_stable_diffusion_v1_4(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: @@ -136,19 +130,17 @@ def test_stable_diffusion_v1_4_bfloat_16(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1 def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( @@ -176,8 +168,8 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1 + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1 def test_stable_diffusion_v1_4_bfloat_16_ddim(self): scheduler = FlaxDDIMScheduler( @@ -211,16 +203,58 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3 assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1 + + def test_jax_memory_efficient_attention(self): + prompt = ( + "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" + " field, close up, split lighting, cinematic" + ) + + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples) + + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="bf16", + dtype=jnp.bfloat16, + safety_checker=None, + ) + + params = replicate(params) + prompt_ids = pipeline.prepare_inputs(prompt) + prompt_ids = shard(prompt_ids) + images = pipeline(prompt_ids, params, prng_seed, jit=True).images + assert images.shape == (num_samples, 1, 512, 512, 3) + slice = images[2, 0, 256, 10:17, 1] + + # With memory efficient attention + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="bf16", + dtype=jnp.bfloat16, + safety_checker=None, + use_memory_efficient_attention=True, + ) + + params = replicate(params) + prompt_ids = pipeline.prepare_inputs(prompt) + prompt_ids = shard(prompt_ids) + images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images + assert images_eff.shape == (num_samples, 1, 512, 512, 3) + slice_eff = images[2, 0, 256, 10:17, 1] + + # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum` + # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now. + assert abs(slice_eff - slice).max() < 1e-2 diff --git a/tests/test_unet_2d_blocks.py b/tests/test_unet_2d_blocks.py index e560240422ac..4d658f282932 100644 --- a/tests/test_unet_2d_blocks.py +++ b/tests/test_unet_2d_blocks.py @@ -57,7 +57,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_output(self): - expected_slice = [0.2440, -0.6953, -0.2140, -0.3874, 0.1966, 1.2077, 0.0441, -0.7718, 0.2800] + expected_slice = [0.2238, -0.7396, -0.2255, -0.3829, 0.1925, 1.1665, 0.0603, -0.7295, 0.1983] super().test_output(expected_slice) @@ -175,7 +175,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_output(self): - expected_slice = [0.1879, 2.2653, 0.5987, 1.1568, -0.8454, -1.6109, -0.8919, 0.8306, 1.6758] + expected_slice = [0.0187, 2.4220, 0.4484, 1.1203, -0.6121, -1.5122, -0.8270, 0.7851, 1.8335] super().test_output(expected_slice) @@ -237,7 +237,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_output(self): - expected_slice = [-0.2796, -0.4364, -0.1067, -0.2693, 0.1894, 0.3869, -0.3470, 0.4584, 0.5091] + expected_slice = [-0.1403, -0.3515, -0.0420, -0.1425, 0.3167, 0.5094, -0.2181, 0.5931, 0.5582] super().test_output(expected_slice) diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py index c00feb9d8e3f..ff9285c63f16 100644 --- a/utils/check_doc_toc.py +++ b/utils/check_doc_toc.py @@ -43,7 +43,7 @@ def clean_doc_toc(doc_list): new_doc = [] for duplicate_key in duplicates: - titles = list(set(doc["title"] for doc in doc_list if doc["local"] == duplicate_key)) + titles = list({doc["title"] for doc in doc_list if doc["local"] == duplicate_key}) if len(titles) > 1: raise ValueError( f"{duplicate_key} is present several times in the documentation table of content at " diff --git a/utils/check_repo.py b/utils/check_repo.py index 2cdb9af62de9..cfd2964f9dcc 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -219,7 +219,7 @@ def check_model_list(): # Get the models from the directory structure of `src/transformers/models/` models = [model for model in dir(diffusers.models) if not model.startswith("__")] - missing_models = sorted(list(set(_models).difference(models))) + missing_models = sorted(set(_models).difference(models)) if missing_models: raise Exception( f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}." @@ -429,7 +429,7 @@ def get_all_auto_configured_models(): for attr_name in dir(diffusers.models.auto.modeling_flax_auto): if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"): result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name))) - return [cls for cls in result] + return list(result) def ignore_unautoclassed(model_name): diff --git a/utils/overwrite_expected_slice.py b/utils/overwrite_expected_slice.py index 95799f9ca625..7aa66727150a 100644 --- a/utils/overwrite_expected_slice.py +++ b/utils/overwrite_expected_slice.py @@ -67,7 +67,7 @@ def overwrite_file(file, class_name, test_name, correct_line, done_test): def main(correct, fail=None): if fail is not None: with open(fail, "r") as f: - test_failures = set([l.strip() for l in f.readlines()]) + test_failures = {l.strip() for l in f.readlines()} else: test_failures = None diff --git a/utils/stale.py b/utils/stale.py index 36631b65a3ba..12932f31c243 100644 --- a/utils/stale.py +++ b/utils/stale.py @@ -38,7 +38,7 @@ def main(): open_issues = repo.get_issues(state="open") for issue in open_issues: - comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) + comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True) last_comment = comments[0] if len(comments) > 0 else None if ( last_comment is not None