diff --git a/docs/source/en/using-diffusers/loading.md b/docs/source/en/using-diffusers/loading.md
index d35f3fb548e9..9d5534154fc8 100644
--- a/docs/source/en/using-diffusers/loading.md
+++ b/docs/source/en/using-diffusers/loading.md
@@ -179,6 +179,210 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
)
```
+### Switch loaded pipelines
+
+Many Diffusers pipelines use the same pretrained model as [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`], but implement specific features to help you achieve better generation results. This guide shows you how to use the `from_pipe` API to create multiple pipelines without increasing memory usage, so you can easily switch between them to use different features.
+
+Let's take an example where we first create a [`StableDiffusionPipeline`] and then reuse the already loaded model components to create a [`StableDiffusionSAGPipeline`] to enhance generation quality.
+
+First, we generate an image of a bear eating pizza using Stable Diffusion with the IP-Adapter:
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
+from diffusers.utils import load_image
+
+base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
+num_inference_steps = 50
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
+prompt = "bear eats pizza"
+negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality"
+
+pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
+pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+pipe_sd.set_ip_adapter_scale(0.6)
+pipe_sd.to("cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+out_sd = pipe_sd(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ ip_adapter_image=image,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+).images[0]
+```
+
+Let's take a look at the image and also print out the memory used:
+
+
+

+
+
+```python
+def bytes_to_giga_bytes(bytes):
+ return bytes / 1024 / 1024 / 1024
+print(
+ f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
+)
+```
+
+```bash
+Max memory allocated: 4.406213283538818 GB
+```
+
+Now, we can use `from_pipe` to switch to the SAG pipeline.
+
+```python
+pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
+```
+
+The SAG pipeline already has the IP-Adapter loaded, so you can pass the same bear image as `ip_adapter_image`:
+
+```python
+generator = torch.Generator(device="cpu").manual_seed(33)
+out_sag = pipe_sag(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    ip_adapter_image=image,
+    num_inference_steps=num_inference_steps,
+    generator=generator,
+    guidance_scale=1.0,
+    sag_scale=0.75,
+).images[0]
+```
+
+You can see a pretty nice improvement in the output.
+
+
+

+
+
+Now we have both `StableDiffusionPipeline` and `StableDiffusionSAGPipeline` co-existing with the same loaded model components, and you can use them interchangeably without additional memory.
+
+```python
+print(
+ f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
+)
+```
+
+```bash
+Max memory allocated: 4.406213283538818 GB
+```
+
+Let's unload the IP-Adapter from the SAG pipeline. It's important to note that methods like `load_ip_adapter` and `unload_ip_adapter` modify the state of the model components. Therefore, when you use these methods on one pipeline, it affects all other pipelines that share the same model components.
+
+```python
+pipe_sag.unload_ip_adapter()
+```
+
+If you try to use the Stable Diffusion pipeline with the IP-Adapter again, it will fail:
+
+```python
+generator = torch.Generator(device="cpu").manual_seed(33)
+out_sd = pipe_sd(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ ip_adapter_image=image,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+).images[0]
+```
+
+```bash
+AttributeError: 'NoneType' object has no attribute 'image_projection_layers'
+```
+
+Please note that pipeline methods may not function properly on a new pipeline created with `from_pipe`. For instance, the `enable_model_cpu_offload` method installs hooks on the model components based on a unique offloading sequence for each pipeline, so if the models are executed in a different order in the new pipeline, CPU offloading may not work correctly.
+
+To ensure proper functionality, we recommend re-applying such pipeline methods on the new pipeline created with `from_pipe`.
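+
+For example, to use CPU offloading with the SAG pipeline created above, enable it on the new pipeline itself rather than relying on hooks installed on `pipe_sd`. This is a minimal sketch, assuming the `pipe_sag` pipeline from the earlier step:
+
+```python
+# re-apply CPU offloading on the pipeline created with `from_pipe`
+pipe_sag.enable_model_cpu_offload()
+```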
+
+You can also add or remove model components when you create new pipelines. Let's now create an AnimateDiff pipeline with an additional `MotionAdapter` module.
+
+```python
+from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+
+pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
+pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
+# load ip_adapter again and load lora weights
+pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
+pipe_animate.to("cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
+out = pipe_animate(
+    prompt=prompt,
+    num_frames=16,
+    num_inference_steps=num_inference_steps,
+    ip_adapter_image=image,
+ generator=generator,
+).frames[0]
+export_to_gif(out, "out_animate.gif")
+```
+
+

+
+
+
+When creating multiple pipelines with the `from_pipe` method, the total memory requirement is determined by the pipeline with the highest memory usage, regardless of how many pipelines you create.
+
+For example, we created three pipelines, `StableDiffusionPipeline`, `StableDiffusionSAGPipeline`, and `AnimateDiffPipeline`, and `AnimateDiffPipeline` has the highest memory requirement, so the total memory usage is based on the memory requirement of `AnimateDiffPipeline` alone.
+
+Creating additional pipelines therefore does not increase the total memory requirement. Each pipeline can be used interchangeably without any additional memory overhead.
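+
+You can check this yourself by printing the peak memory again after running the AnimateDiff pipeline, reusing the `bytes_to_giga_bytes` helper defined earlier:
+
+```python
+print(
+    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
+)
+```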
+
+
+Did you know that you can use `from_pipe` with a community pipeline? Let's look at an example that uses a long negative prompt and prompt weighting!
+
+```python
+pipe_lpw = DiffusionPipeline.from_pipe(
+ pipe_sd,
+ custom_pipeline="lpw_stable_diffusion",
+).to("cuda")
+
+prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
+neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
+generator = torch.Generator(device="cpu").manual_seed(33)
+out_lpw = pipe_lpw.text2img(
+ prompt,
+ negative_prompt=neg_prompt,
+    width=512,
+    height=512,
+ max_embeddings_multiples=3,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+).images[0]
+```
+
+
+

+
+
+Let's run [`StableDiffusionPipeline`] with the same inputs to compare: the result from the long prompt weighting pipeline is more aligned with the text prompt.
+
+```python
+generator = torch.Generator(device="cpu").manual_seed(33)
+out_sd = pipe_sd(
+ prompt=prompt,
+    negative_prompt=neg_prompt,
+ generator=generator,
+ num_inference_steps=num_inference_steps,
+).images[0]
+out_sd
+```
+
+

+
+
+
+You can easily switch between different pipelines using the `from_pipe` method, much like turning a feature of your pipeline on and off. To switch between tasks, use `from_pipe` with `AutoPipeline`, which automatically identifies the pipeline class based on the task. You can find more information about this feature in the [AutoPipeline guide](https://huggingface.co/docs/diffusers/tutorials/autopipeline).
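+
+For instance, here is a minimal sketch that reuses the loaded Stable Diffusion components for an image-to-image task with `AutoPipelineForImage2Image`, assuming the `pipe_sd` pipeline from the earlier steps is still loaded:
+
+```python
+from diffusers import AutoPipelineForImage2Image
+
+# reuse the components of `pipe_sd` for an image-to-image task
+pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_sd)
+```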
+
+
## Checkpoint variants
A checkpoint variant is usually a checkpoint whose weights are:
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 78d93bfb7081..3a47105e574d 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -439,7 +439,9 @@ class StableDiffusionLongPromptWeightingPipeline(
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+    model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index ab2eac4c9a9a..88c0b967c099 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -17,7 +17,7 @@
import torch.nn as nn
import torch.utils.checkpoint
-from ...configuration_utils import ConfigMixin, register_to_config
+from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
@@ -393,8 +393,11 @@ def from_unet2d(
):
has_motion_adapter = motion_adapter is not None
+ if has_motion_adapter:
+ motion_adapter.to(device=unet.device)
+
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
- config = unet.config
+ config = dict(unet.config)
config["_class_name"] = cls.__name__
down_blocks = []
@@ -427,6 +430,7 @@ def from_unet2d(
if not config.get("num_attention_heads"):
config["num_attention_heads"] = config["attention_head_dim"]
+ config = FrozenDict(config)
model = cls.from_config(config)
if not load_weights:
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index c8f8fd1d2098..12347227a15e 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -131,7 +131,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
- unet: UNet2DConditionModel,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 30c17eec119d..11b2c549096b 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -292,6 +292,39 @@ def get_class_obj_and_candidates(
return class_obj, class_candidates
+def _get_custom_pipeline_class(
+ custom_pipeline,
+ repo_id=None,
+ hub_revision=None,
+ class_name=None,
+ cache_dir=None,
+ revision=None,
+):
+ if custom_pipeline.endswith(".py"):
+ path = Path(custom_pipeline)
+ # decompose into folder & file
+ file_name = path.name
+ custom_pipeline = path.parent.absolute()
+ elif repo_id is not None:
+ file_name = f"{custom_pipeline}.py"
+ custom_pipeline = repo_id
+ else:
+ file_name = CUSTOM_PIPELINE_FILE_NAME
+
+ if repo_id is not None and hub_revision is not None:
+ # if we load the pipeline code from the Hub
+ # make sure to overwrite the `revision`
+ revision = hub_revision
+
+ return get_class_from_dynamic_module(
+ custom_pipeline,
+ module_file=file_name,
+ class_name=class_name,
+ cache_dir=cache_dir,
+ revision=revision,
+ )
+
+
def _get_pipeline_class(
class_obj,
config=None,
@@ -304,25 +337,10 @@ def _get_pipeline_class(
revision=None,
):
if custom_pipeline is not None:
- if custom_pipeline.endswith(".py"):
- path = Path(custom_pipeline)
- # decompose into folder & file
- file_name = path.name
- custom_pipeline = path.parent.absolute()
- elif repo_id is not None:
- file_name = f"{custom_pipeline}.py"
- custom_pipeline = repo_id
- else:
- file_name = CUSTOM_PIPELINE_FILE_NAME
-
- if repo_id is not None and hub_revision is not None:
- # if we load the pipeline code from the Hub
- # make sure to overwrite the `revision`
- revision = hub_revision
-
- return get_class_from_dynamic_module(
+ return _get_custom_pipeline_class(
custom_pipeline,
- module_file=file_name,
+ repo_id=repo_id,
+ hub_revision=hub_revision,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index f59c25c19183..a98d736aa557 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -21,7 +21,7 @@
import sys
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
import numpy as np
import PIL.Image
@@ -43,7 +43,7 @@
from ..configuration_utils import ConfigMixin
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
@@ -72,6 +72,7 @@
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
_fetch_class_library_tuple,
+ _get_custom_pipeline_class,
_get_pipeline_class,
_unwrap_model,
is_safetensors_compatible,
@@ -1475,6 +1476,18 @@ def _get_signature_keys(cls, obj):
return expected_modules, optional_parameters
+ @classmethod
+ def _get_signature_types(cls):
+ signature_types = {}
+ for k, v in inspect.signature(cls.__init__).parameters.items():
+ if inspect.isclass(v.annotation):
+ signature_types[k] = (v.annotation,)
+ elif get_origin(v.annotation) == Union:
+ signature_types[k] = get_args(v.annotation)
+ else:
+                logger.warning(f"cannot get type annotation for parameter {k} of {cls}.")
+ return signature_types
+
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -1653,6 +1666,128 @@ def set_attention_slice(self, slice_size: Optional[int]):
for module in modules:
module.set_attention_slice(slice_size)
+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ r"""
+ Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing pipeline components without reallocating additional memory.
+
+ Arguments:
+ pipeline (`DiffusionPipeline`):
+ The pipeline from which to create a new pipeline.
+
+ Returns:
+ `DiffusionPipeline`:
+ A new pipeline with the same weights and configurations as `pipeline`.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
+ ```
+ """
+
+ original_config = dict(pipeline.config)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ # derive the pipeline class to instantiate
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
+ custom_revision = kwargs.pop("custom_revision", None)
+
+ if custom_pipeline is not None:
+ pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
+ else:
+ pipeline_class = cls
+
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
+ # e.g. `image_encoder` for StableDiffusionPipeline
+ parameters = inspect.signature(cls.__init__).parameters
+ true_optional_modules = set(
+ {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules}
+ )
+
+ # get the class of each component based on its type hint
+ # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode}
+ component_types = pipeline_class._get_signature_types()
+
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ original_class_obj = {}
+ for name, component in pipeline.components.items():
+ if name in expected_modules and name not in passed_class_obj:
+                # for model components, we will not switch over if the class does not match the type hint in the new pipeline's signature
+ if (
+ not isinstance(component, ModelMixin)
+ or type(component) in component_types[name]
+ or (component is None and name in cls._optional_components)
+ ):
+ original_class_obj[name] = component
+ else:
+                    logger.warning(
+                        f"component {name} is not switched over to new pipeline because its type does not match the expected one."
+                        f" {name} is {type(component)} while the new pipeline expects {component_types[name]}."
+                        f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
+ )
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+ original_pipe_kwargs = {
+ k: original_config[k]
+ for k in original_config.keys()
+ if k in optional_kwargs and k not in passed_pipe_kwargs
+ }
+
+        # config attributes that were not expected by the pipeline are stored as its private attributes
+ # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
+ # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
+ additional_pipe_kwargs = [
+ k[1:]
+ for k in original_config.keys()
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+ ]
+ for k in additional_pipe_kwargs:
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+ pipeline_kwargs = {
+ **passed_class_obj,
+ **original_class_obj,
+ **passed_pipe_kwargs,
+ **original_pipe_kwargs,
+ **kwargs,
+ }
+
+ # store unused config as private attribute in the new pipeline
+ unused_original_config = {
+ f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
+ }
+
+ missing_modules = (
+ set(expected_modules)
+ - set(pipeline._optional_components)
+ - set(pipeline_kwargs.keys())
+ - set(true_optional_modules)
+ )
+
+ if len(missing_modules) > 0:
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
+ )
+
+ new_pipeline = pipeline_class(**pipeline_kwargs)
+ if pretrained_model_name_or_path is not None:
+ new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ new_pipeline.register_to_config(**unused_original_config)
+
+ if torch_dtype is not None:
+ new_pipeline.to(dtype=torch_dtype)
+
+ return new_pipeline
+
class StableDiffusionMixin:
r"""
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 03c80b46b806..8e43676494a5 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -902,6 +902,7 @@ def __call__(
if attn_res is None:
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
self.attention_store = AttentionStore(attn_res)
+ original_attn_proc = self.unet.attn_processors
self.register_attention_control()
# default config for step size from original repo
@@ -1016,6 +1017,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
+ # make sure to set the original attention processors back
+ self.unet.set_attn_processor(original_attn_proc)
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index 7dfefd94da47..2e7a1fa41b58 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -750,6 +750,7 @@ def __call__(
)
# 7. Denoising loop
+ original_attn_proc = self.unet.attn_processors
store_processor = CrossAttnStoreProcessor()
self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -848,6 +849,8 @@ def get_map_size(module, input, output):
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
self.maybe_free_model_hooks()
+ # make sure to set the original attention processors back
+ self.unet.set_attn_processor(original_attn_proc)
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index d45408e9543a..5ac211eef80f 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -329,13 +329,6 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- processor = (
- CrossFrameAttnProcessor2_0(batch_size=2)
- if hasattr(F, "scaled_dot_product_attention")
- else CrossFrameAttnProcessor(batch_size=2)
- )
- self.unet.set_attn_processor(processor)
-
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -616,6 +609,15 @@ def __call__(
assert num_videos_per_prompt == 1
+ # set the processor
+ original_attn_proc = self.unet.attn_processors
+ processor = (
+ CrossFrameAttnProcessor2_0(batch_size=2)
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossFrameAttnProcessor(batch_size=2)
+ )
+ self.unet.set_attn_processor(processor)
+
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(negative_prompt, str):
@@ -739,6 +741,8 @@ def __call__(
# Offload all models
self.maybe_free_model_hooks()
+ # make sure to set the original attention processors back
+ self.unet.set_attn_processor(original_attn_proc)
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index eaa2760363a9..07d7e92e11d9 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -411,14 +411,6 @@ def __init__(
else:
self.watermark = None
- processor = (
- CrossFrameAttnProcessor2_0(batch_size=2)
- if hasattr(F, "scaled_dot_product_attention")
- else CrossFrameAttnProcessor(batch_size=2)
- )
-
- self.unet.set_attn_processor(processor)
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1084,6 +1076,15 @@ def __call__(
assert num_videos_per_prompt == 1
+ # set the processor
+ original_attn_proc = self.unet.attn_processors
+ processor = (
+ CrossFrameAttnProcessor2_0(batch_size=2)
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossFrameAttnProcessor(batch_size=2)
+ )
+ self.unet.set_attn_processor(processor)
+
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(negative_prompt, str):
@@ -1305,9 +1306,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload last model to CPU manually for max memory savings
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
+ self.maybe_free_model_hooks()
+ # make sure to set the original attention processors back
+ self.unet.set_attn_processor(original_attn_proc)
if not return_dict:
return (image,)
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 802da19ba654..c61a1ee45b89 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -18,7 +18,12 @@
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+)
def to_np(tensor):
@@ -29,7 +34,7 @@ def to_np(tensor):
class AnimateDiffPipelineFastTests(
- IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase
+ IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffPipeline
params = TEXT_TO_IMAGE_PARAMS
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index 81317b44a3c9..cabfd29e0d32 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -18,7 +18,7 @@
from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
def to_np(tensor):
@@ -28,7 +28,9 @@ def to_np(tensor):
return tensor
-class AnimateDiffVideoToVideoPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffVideoToVideoPipelineFastTests(
+ IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
pipeline_class = AnimateDiffVideoToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
index 16ca4bc0957d..3a89452585fb 100644
--- a/tests/pipelines/pia/test_pia.py
+++ b/tests/pipelines/pia/test_pia.py
@@ -17,7 +17,7 @@
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import floats_tensor, torch_device
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
def to_np(tensor):
@@ -27,7 +27,7 @@ def to_np(tensor):
return tensor
-class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
pipeline_class = PIAPipeline
params = frozenset(
[
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
index 5e296f945ded..22548cd0eff2 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
@@ -35,7 +35,12 @@
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+ PipelineFromPipeTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
torch.backends.cuda.matmul.allow_tf32 = False
@@ -43,7 +48,11 @@
@skip_mps
class StableDiffusionAttendAndExcitePipelineFastTests(
- PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
):
pipeline_class = StableDiffusionAttendAndExcitePipeline
test_attention_slicing = False
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
index e9a9f79aa989..1cb03ddd96d7 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
@@ -43,13 +43,15 @@
)
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
-class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionDiffEditPipelineFastTests(
+ PipelineLatentTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
pipeline_class = StableDiffusionDiffEditPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"}
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"}
diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
index cb951d5dd833..36bccaac9d93 100644
--- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
@@ -46,7 +46,7 @@
)
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
+from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference
enable_full_determinism()
@@ -337,7 +337,9 @@ def test_adapter_lcm_custom_timesteps(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionFullAdapterPipelineFastTests(
+ AdapterTests, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
def get_dummy_components(self, time_cond_proj_dim=None):
return super().get_dummy_components("full_adapter", time_cond_proj_dim=time_cond_proj_dim)
diff --git a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
index 3b8383b12793..405809aee19e 100644
--- a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
+++ b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
@@ -33,14 +33,23 @@
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+ PipelineFromPipeTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
enable_full_determinism()
class GligenPipelineFastTests(
- PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
):
pipeline_class = StableDiffusionGLIGENPipeline
params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
index 111e8d8df491..f9f8b044a916 100644
--- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
+++ b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
@@ -42,14 +42,23 @@
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+ PipelineFromPipeTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
enable_full_determinism()
class GligenTextImagePipelineFastTests(
- PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
):
pipeline_class = StableDiffusionGLIGENTextImagePipeline
params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
index f275c59c7ca5..35a654542023 100644
--- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
+++ b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
@@ -32,7 +32,12 @@
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
enable_full_determinism()
@@ -40,7 +45,11 @@
@skip_mps
class StableDiffusionPanoramaPipelineFastTests(
- IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+ IPAdapterTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
):
pipeline_class = StableDiffusionPanoramaPipeline
params = TEXT_TO_IMAGE_PARAMS
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
index f210f3e75ad3..1d4e66bd65f0 100644
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
@@ -32,14 +32,23 @@
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
enable_full_determinism()
class StableDiffusionSAGPipelineFastTests(
- IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+ IPAdapterTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
):
pipeline_class = StableDiffusionSAGPipeline
params = TEXT_TO_IMAGE_PARAMS
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 90ff834ed9a1..411f6a3b8092 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -26,10 +26,12 @@
DDIMScheduler,
DiffusionPipeline,
StableDiffusionPipeline,
+ StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
+from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
@@ -511,6 +513,186 @@ def test_multi_vae(self):
assert out_vae_np.shape == out_np.shape
+@require_torch
+class PipelineFromPipeTesterMixin:
+ @property
+ def original_pipeline_class(self):
+ if "xl" in self.pipeline_class.__name__.lower():
+ original_pipeline_class = StableDiffusionXLPipeline
+ else:
+ original_pipeline_class = StableDiffusionPipeline
+
+ return original_pipeline_class
+
+ def get_dummy_inputs_pipe(self, device, seed=0):
+ inputs = self.get_dummy_inputs(device, seed=seed)
+ inputs["output_type"] = "np"
+ inputs["return_dict"] = False
+ return inputs
+
+ def get_dummy_inputs_for_pipe_original(self, device, seed=0):
+ inputs = {}
+ for k, v in self.get_dummy_inputs_pipe(device, seed=seed).items():
+ if k in set(inspect.signature(self.original_pipeline_class.__call__).parameters.keys()):
+ inputs[k] = v
+ return inputs
+
+ def test_from_pipe_consistent_config(self):
+ if self.original_pipeline_class == StableDiffusionPipeline:
+ original_repo = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ original_kwargs = {"requires_safety_checker": False}
+ elif self.original_pipeline_class == StableDiffusionXLPipeline:
+ original_repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+ original_kwargs = {"requires_aesthetics_score": True, "force_zeros_for_empty_prompt": False}
+ else:
+ raise ValueError(
+ "original_pipeline_class must be either StableDiffusionPipeline or StableDiffusionXLPipeline"
+ )
+
+ # create original_pipeline_class(sd/sdxl)
+ pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+ # original_pipeline_class(sd/sdxl) -> pipeline_class
+ pipe_components = self.get_dummy_components()
+ pipe_additional_components = {}
+ for name, component in pipe_components.items():
+ if name not in pipe_original.components:
+ pipe_additional_components[name] = component
+
+ pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+ # pipeline_class -> original_pipeline_class(sd/sdxl)
+ original_pipe_additional_components = {}
+ for name, component in pipe_original.components.items():
+ if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+ original_pipe_additional_components[name] = component
+
+ pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+ # compare the config
+ original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+ original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+ assert original_config_2 == original_config
+
+ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
+ components = self.get_dummy_components()
+ original_expected_modules, _ = self.original_pipeline_class._get_signature_keys(self.original_pipeline_class)
+
+ # pipeline components that are also expected to be in the original pipeline
+ original_pipe_components = {}
+ # additional components that are not in the pipeline, but expected in the original pipeline
+ original_pipe_additional_components = {}
+ # additional components that are in the pipeline, but not expected in the original pipeline
+ current_pipe_additional_components = {}
+
+ for name, component in components.items():
+ if name in original_expected_modules:
+ original_pipe_components[name] = component
+ else:
+ current_pipe_additional_components[name] = component
+ for name in original_expected_modules:
+ if name not in original_pipe_components:
+ if name in self.original_pipeline_class._optional_components:
+ original_pipe_additional_components[name] = None
+ else:
+                    raise ValueError(f"missing required module for {self.original_pipeline_class}: {name}")
+
+ pipe_original = self.original_pipeline_class(**original_pipe_components, **original_pipe_additional_components)
+ for component in pipe_original.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_original.to(torch_device)
+ pipe_original.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs_for_pipe_original(torch_device)
+ output_original = pipe_original(**inputs)[0]
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs_pipe(torch_device)
+ output = pipe(**inputs)[0]
+
+ pipe_from_original = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
+ pipe_from_original.to(torch_device)
+ pipe_from_original.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs_pipe(torch_device)
+ output_from_original = pipe_from_original(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_from_original)).max()
+ self.assertLess(
+ max_diff,
+ expected_max_diff,
+ "The outputs of the pipelines created with `from_pipe` and `__init__` are different.",
+ )
+
+ inputs = self.get_dummy_inputs_for_pipe_original(torch_device)
+ output_original_2 = pipe_original(**inputs)[0]
+
+ max_diff = np.abs(to_np(output_original) - to_np(output_original_2)).max()
+ self.assertLess(max_diff, expected_max_diff, "`from_pipe` should not change the output of original pipeline.")
+
+ for component in pipe_original.components.values():
+ if hasattr(component, "attn_processors"):
+ assert all(
+ type(proc) == AttnProcessor for proc in component.attn_processors.values()
+ ), "`from_pipe` changed the attention processor in original pipeline."
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
+ reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
+ )
+ def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1e-3):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs_pipe(torch_device)
+ output = pipe(**inputs)[0]
+
+ original_expected_modules, _ = self.original_pipeline_class._get_signature_keys(self.original_pipeline_class)
+ # pipeline components that are also expected to be in the original pipeline
+ original_pipe_components = {}
+ # additional components that are not in the pipeline, but expected in the original pipeline
+ original_pipe_additional_components = {}
+ # additional components that are in the pipeline, but not expected in the original pipeline
+ current_pipe_additional_components = {}
+ for name, component in components.items():
+ if name in original_expected_modules:
+ original_pipe_components[name] = component
+ else:
+ current_pipe_additional_components[name] = component
+ for name in original_expected_modules:
+ if name not in original_pipe_components:
+ if name in self.original_pipeline_class._optional_components:
+ original_pipe_additional_components[name] = None
+ else:
+                    raise ValueError(f"missing required module for {self.original_pipeline_class}: {name}")
+
+ pipe_original = self.original_pipeline_class(**original_pipe_components, **original_pipe_additional_components)
+ for component in pipe_original.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_original.set_progress_bar_config(disable=None)
+ pipe_from_original = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
+ pipe_from_original.enable_model_cpu_offload()
+ pipe_from_original.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs_pipe(torch_device)
+ output_from_original = pipe_from_original(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_from_original)).max()
+ self.assertLess(
+ max_diff,
+ expected_max_diff,
+ "The outputs of the pipelines created with `from_pipe` and `__init__` are different.",
+ )
+
+
@require_torch
class PipelineKarrasSchedulerTesterMixin:
"""
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
index eb90f52b72ec..bcde20a36c34 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
@@ -30,7 +30,7 @@
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin
enable_full_determinism()
@@ -43,7 +43,7 @@ def to_np(tensor):
return tensor
-class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
pipeline_class = TextToVideoZeroSDXLPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS