24
24
from ...models import AutoencoderKL , UNet2DConditionModel
25
25
from ...models .embeddings import get_timestep_embedding
26
26
from ...schedulers import KarrasDiffusionSchedulers
27
- from ...utils import logging , randn_tensor , replace_example_docstring
27
+ from ...utils import is_accelerate_version , logging , randn_tensor , replace_example_docstring
28
28
from ..pipeline_utils import DiffusionPipeline , ImagePipelineOutput
29
29
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
30
30
@@ -180,6 +180,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
180
180
if cpu_offloaded_model is not None :
181
181
cpu_offload (cpu_offloaded_model , device )
182
182
183
+ def enable_model_cpu_offload (self , gpu_id = 0 ):
184
+ r"""
185
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
186
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
187
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
188
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
189
+ """
190
+ if is_accelerate_available () and is_accelerate_version (">=" , "0.17.0.dev0" ):
191
+ from accelerate import cpu_offload_with_hook
192
+ else :
193
+ raise ImportError ("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." )
194
+
195
+ device = torch .device (f"cuda:{ gpu_id } " )
196
+
197
+ if self .device .type != "cpu" :
198
+ self .to ("cpu" , silence_dtype_warnings = True )
199
+ torch .cuda .empty_cache () # otherwise we don't see the memory savings (but they probably exist)
200
+
201
+ hook = None
202
+ for cpu_offloaded_model in [self .text_encoder , self .image_encoder , self .unet , self .vae ]:
203
+ _ , hook = cpu_offload_with_hook (cpu_offloaded_model , device , prev_module_hook = hook )
204
+
205
+ # We'll offload the last model manually.
206
+ self .final_offload_hook = hook
207
+
183
208
@property
184
209
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
185
210
def _execution_device (self ):
@@ -548,6 +573,7 @@ def noise_image_embeddings(
548
573
549
574
noise_level = torch .tensor ([noise_level ] * image_embeds .shape [0 ], device = image_embeds .device )
550
575
576
+ self .image_normalizer .to (image_embeds .device )
551
577
image_embeds = self .image_normalizer .scale (image_embeds )
552
578
553
579
image_embeds = self .image_noising_scheduler .add_noise (image_embeds , timesteps = noise_level , noise = noise )
@@ -571,8 +597,8 @@ def noise_image_embeddings(
571
597
@replace_example_docstring (EXAMPLE_DOC_STRING )
572
598
def __call__ (
573
599
self ,
574
- prompt : Union [str , List [str ]] = None ,
575
600
image : Union [torch .FloatTensor , PIL .Image .Image ] = None ,
601
+ prompt : Union [str , List [str ]] = None ,
576
602
height : Optional [int ] = None ,
577
603
width : Optional [int ] = None ,
578
604
num_inference_steps : int = 20 ,
@@ -597,8 +623,8 @@ def __call__(
597
623
598
624
Args:
599
625
prompt (`str` or `List[str]`, *optional*):
600
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
601
- instead .
626
+ The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
627
+ used or prompt is initialized to `""` .
602
628
image (`torch.FloatTensor` or `PIL.Image.Image`):
603
629
`Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
604
630
the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -674,6 +700,9 @@ def __call__(
674
700
height = height or self .unet .config .sample_size * self .vae_scale_factor
675
701
width = width or self .unet .config .sample_size * self .vae_scale_factor
676
702
703
+ if prompt is None and prompt_embeds is None :
704
+ prompt = len (image ) * ["" ] if isinstance (image , list ) else ""
705
+
677
706
# 1. Check inputs. Raise error if not correct
678
707
self .check_inputs (
679
708
prompt = prompt ,
@@ -777,6 +806,10 @@ def __call__(
777
806
# 9. Post-processing
778
807
image = self .decode_latents (latents )
779
808
809
+ # Offload last model to CPU
810
+ if hasattr (self , "final_offload_hook" ) and self .final_offload_hook is not None :
811
+ self .final_offload_hook .offload ()
812
+
780
813
# 10. Convert to PIL
781
814
if output_type == "pil" :
782
815
image = self .numpy_to_pil (image )
0 commit comments