
Commit dbcb15c

[Stable UnCLIP] Finish Stable UnCLIP (#2814)
* up
* fix more 7
* up
* finish
Parent: c4892f1

File tree: 3 files changed (+79, -5 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Lines changed: 31 additions & 1 deletion
@@ -22,7 +22,7 @@
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

@@ -178,6 +178,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
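For context, a minimal usage sketch of the new method follows. The checkpoint id and prompt are illustrative assumptions, not part of the commit; any Stable unCLIP text-to-image weights work the same way:

import torch
from diffusers import StableUnCLIPPipeline

# Hypothetical checkpoint id -- substitute the Stable unCLIP weights you actually use.
pipe = StableUnCLIPPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16
)

# Do NOT call pipe.to("cuda") here: the hooks installed below manage placement,
# moving each sub-model (text encoders, unet, vae) to cuda:0 only around its forward pass.
pipe.enable_model_cpu_offload()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")

Moving the whole pipeline to CUDA beforehand would defeat the offload, which is why the method itself moves everything back to the CPU and empties the cache before installing hooks.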
@@ -581,6 +606,7 @@ def noise_image_embeddings(

         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)

         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
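The added `.to()` call matters under CPU offload: the normalizer's parameters can still sit on the CPU while the embeddings are already on the GPU. A standalone sketch of the failure it prevents (assumes a CUDA device is available; sizes are illustrative):

import torch
from torch import nn

mean = nn.Parameter(torch.zeros(1, 768))     # parameter left on the CPU
embeds = torch.randn(2, 768, device="cuda")  # embeddings produced on the GPU

# embeds - mean  ->  RuntimeError: Expected all tensors to be on the same device
mean = nn.Parameter(mean.to(embeds.device))  # the commit's fix, in miniature
scaled = embeds - mean                       # now fine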
@@ -884,6 +910,10 @@ def __call__(
         # 14. Post-processing
         image = self.decode_latents(latents)

+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 15. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
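The `final_offload_hook` exists because `cpu_offload_with_hook` chains modules: each hooked module's forward pass offloads its predecessor, so the last module in the chain (here the `vae`) stays on the GPU until someone offloads it by hand, which the pipeline does right after decoding. A self-contained sketch of that pattern (assumes `accelerate>=0.17` and a CUDA device; the two `nn.Linear` modules are stand-ins for the pipeline's sub-models):

import torch
from torch import nn
from accelerate import cpu_offload_with_hook

first, last = nn.Linear(8, 8), nn.Linear(8, 8)
device = torch.device("cuda:0")

hook = None
for module in (first, last):
    _, hook = cpu_offload_with_hook(module, device, prev_module_hook=hook)

x = torch.randn(1, 8, device=device)
out = last(first(x))  # running `last` offloads `first` automatically
hook.offload()        # `last` has no successor, so it is offloaded manually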

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py

Lines changed: 37 additions & 4 deletions
@@ -24,7 +24,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

@@ -180,6 +180,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -548,6 +573,7 @@ def noise_image_embeddings(

         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)

         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -571,8 +597,8 @@ def noise_image_embeddings(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 20,
@@ -597,8 +623,8 @@ def __call__(

         Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+                used or prompt is initialized to `""`.
             image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
                 the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -674,6 +700,9 @@ def __call__(
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor

+        if prompt is None and prompt_embeds is None:
+            prompt = len(image) * [""] if isinstance(image, list) else ""
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
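Together, the argument reorder and the empty-prompt default turn the img2img pipeline into a pure image-variation API: the image can be passed as the only positional argument. A minimal sketch of the new call pattern (the checkpoint id and input file are illustrative assumptions):

import torch
from PIL import Image
from diffusers import StableUnCLIPImg2ImgPipeline

# Hypothetical checkpoint id -- substitute the Stable unCLIP variation weights you use.
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # uses the offload path added in this commit

init = Image.open("input.png").convert("RGB")  # any RGB image on disk
variation = pipe(init).images[0]               # no prompt: it defaults to ""
variation.save("variation.png")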
@@ -777,6 +806,10 @@ def __call__(
         # 9. Post-processing
         image = self.decode_latents(latents)

+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py

Lines changed: 11 additions & 0 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional, Union
+
 import torch
 from torch import nn

@@ -37,6 +39,15 @@ def __init__(
         self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.std = nn.Parameter(torch.ones(1, embedding_dim))

+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
+        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
+        return self
+
     def scale(self, embeds):
         embeds = (embeds - self.mean) * 1.0 / self.std
         return embeds
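This custom `to` re-wraps `mean` and `std` as fresh `nn.Parameter`s on the target device and dtype, which is what lets `noise_image_embeddings` move the normalizer on the fly. A small usage sketch; the import path matches the file above, and the embedding width of 768 is an assumed value, not dictated by the commit:

import torch
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import (
    StableUnCLIPImageNormalizer,
)

normalizer = StableUnCLIPImageNormalizer(embedding_dim=768)
normalizer.to(torch_device="cuda", torch_dtype=torch.float16)  # moves and casts both parameters

embeds = torch.randn(2, 768, device="cuda", dtype=torch.float16)
scaled = normalizer.scale(embeds)  # (embeds - mean) / std, device- and dtype-consistent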
