Skip to content

Commit fb94f42

Browse files
committed
add mixin to rdm & restore audioldm2 & fix quality checks
1 parent 26d72b4 commit fb94f42

File tree

5 files changed

+42
-3
lines changed

5 files changed

+42
-3
lines changed

examples/research_projects/rdm/pipeline_rdm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
     UNet2DConditionModel,
 )
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import EfficiencyMixin
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class RDMPipeline(DiffusionPipeline):
+class RDMPipeline(DiffusionPipeline, EfficiencyMixin):
     r"""
     Pipeline for text-to-image generation using Retrieval Augmented Diffusion.

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,22 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

+    # Copied from diffusers.pipelines.pipeline_utils.EfficiencyMixin.enable_vae_slicing
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    # Copied from diffusers.pipelines.pipeline_utils.EfficiencyMixin.disable_vae_slicing
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,12 @@ def retrieve_timesteps(


 class StableDiffusionPipeline(
-    DiffusionPipeline, EfficiencyMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+    DiffusionPipeline,
+    EfficiencyMixin,
+    TextualInversionLoaderMixin,
+    LoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def apply_model(self, *args, **kwargs):
         return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample


-class StableDiffusionKDiffusionPipeline(DiffusionPipeline, EfficiencyMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionKDiffusionPipeline(
+    DiffusionPipeline, EfficiencyMixin, TextualInversionLoaderMixin, LoraLoaderMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
5355

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class EfficiencyMixin(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class ImagePipelineOutput(metaclass=DummyObject):
     _backends = ["torch"]

0 commit comments

Comments
 (0)