From a72334fe2298fcf65ecfbea2f0a5b6670d20f678 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 24 Mar 2023 15:25:47 +0000
Subject: [PATCH 1/4] up

---
 .../pipeline_stable_unclip.py                 | 31 +++++++++++++++++-
 .../pipeline_stable_unclip_img2img.py         | 32 ++++++++++++++++++-
 .../stable_unclip_image_normalizer.py         | 11 +++++++
 3 files changed, 72 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index a8ba0b504628..f251bf6ff440 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -22,7 +22,7 @@
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
@@ -178,6 +178,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model stays on the GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better because the `unet` is run iteratively.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -884,6 +909,10 @@ def __call__(
         # 14. Post-processing
         image = self.decode_latents(latents)
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 15. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
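The `enable_model_cpu_offload` added here is a drop-in alternative to moving the whole pipeline to CUDA. A minimal usage sketch, assuming a CUDA machine; the checkpoint id is illustrative and not pinned by this patch:

```python
import torch
from diffusers import StableUnCLIPPipeline

# Checkpoint id is illustrative; use whichever stable unCLIP weights you have.
pipe = StableUnCLIPPipeline.from_pretrained(
    "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
)

# Instead of pipe.to("cuda"): each sub-model (text encoder, prior text encoder,
# unet, vae) is moved to the GPU right before its forward pass and offloaded
# once the next sub-model needs the device.
pipe.enable_model_cpu_offload()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```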
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 4a8a4de9580b..5182eb814564 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -24,7 +24,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
@@ -180,6 +180,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model stays on the GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better because the `unet` is run iteratively.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -548,6 +573,7 @@ def noise_image_embeddings(
             noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
 
+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)
 
         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -777,6 +803,10 @@ def __call__(
         # 9. Post-processing
         image = self.decode_latents(latents)
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
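The img2img pipeline gets the same offloading entry point, with `image_encoder` taking the place of `prior_text_encoder` in the hook chain. The chain itself comes from accelerate's `cpu_offload_with_hook`; below is a self-contained sketch of the mechanism with toy `nn.Linear` stand-ins for the real sub-models, assuming accelerate >= 0.17.0 and an available GPU:

```python
import torch
from torch import nn
from accelerate import cpu_offload_with_hook  # requires accelerate >= 0.17.0

device = torch.device("cuda:0")

# Toy stand-ins for the pipeline's sub-models.
encoder, unet, vae = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

# Chain the hooks exactly as the patch does: each model's hook offloads the
# previous model when its own forward runs, so at most one model sits on the GPU.
hook = None
for model in (encoder, unet, vae):
    _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)

x = torch.randn(1, 8).to(device)
x = encoder(x)  # encoder is pulled onto cuda:0 on demand
x = unet(x)     # running unet offloads encoder back to CPU first
x = vae(x)      # running vae offloads unet

# Mirror of final_offload_hook: push the last model (vae) back to CPU manually.
hook.offload()
```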
diff --git a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
index 9c7f190d0505..7362df7e80e7 100644
--- a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
+++ b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
+
 import torch
 from torch import nn
 
@@ -37,6 +39,15 @@ def __init__(
         self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.std = nn.Parameter(torch.ones(1, embedding_dim))
 
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
+        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
+        return self
+
     def scale(self, embeds):
         embeds = (embeds - self.mean) * 1.0 / self.std
         return embeds

From eab90f744d088bd7d1719ced4f5bcb2ac38c31cb Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 24 Mar 2023 15:32:31 +0000
Subject: [PATCH 2/4] fix more 7

---
 .../stable_diffusion/pipeline_stable_unclip_img2img.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 5182eb814564..e7b1eaaf35c9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -597,8 +597,8 @@ def noise_image_embeddings(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 20,
@@ -623,8 +623,8 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` is used
+                or the prompt is initialized to `""`.
             image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
                 the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -700,6 +700,9 @@ def __call__(
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+        if prompt is None and prompt_embeds is None:
+            prompt = ""
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
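Patch 1's `StableUnCLIPImageNormalizer.to` override exists so the `image_normalizer.to(image_embeds.device)` calls (added to the img2img pipeline above, and to the text-to-image pipeline in patch 3 below) can pull the normalizer's `mean`/`std` parameters onto whatever device the embeddings landed on under offloading. A minimal sketch of the mismatch this prevents; `embedding_dim=768` is illustrative:

```python
import torch
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import (
    StableUnCLIPImageNormalizer,
)

normalizer = StableUnCLIPImageNormalizer(embedding_dim=768)  # mean/std start on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
image_embeds = torch.randn(2, 768, device=device)

# Without this call, `scale` would subtract a CPU-resident mean from GPU
# embeddings and raise a device-mismatch error once offloading has moved
# the normalizer back to CPU.
normalizer.to(image_embeds.device)
scaled = normalizer.scale(image_embeds)
```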
From 79567d733fa3189fbedaba9b7a548517009972dc Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 24 Mar 2023 15:35:42 +0000
Subject: [PATCH 3/4] up

---
 .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index f251bf6ff440..1341ec2b284b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -606,6 +606,7 @@ def noise_image_embeddings(
             noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
 
+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)
 
         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)

From 82125d2efb3b0f03237e5d11804cad72c5b87fe3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 24 Mar 2023 15:45:44 +0000
Subject: [PATCH 4/4] finish

---
 .../stable_diffusion/pipeline_stable_unclip_img2img.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index e7b1eaaf35c9..bdebb507a7b5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -701,7 +701,7 @@ def __call__(
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         if prompt is None and prompt_embeds is None:
-            prompt = ""
+            prompt = len(image) * [""] if isinstance(image, list) else ""
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
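With the series applied, the img2img pipeline can be driven with image-only conditioning: patch 2/4 makes the image the first positional argument and defaults the prompt to `""`, and patch 4/4 makes that default track the image batch size via the `isinstance(image, list)` branch. A usage sketch under stated assumptions; the checkpoint id and image URL are illustrative, and list-of-image input is inferred from that branch rather than spelled out in the type hints:

```python
import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image

# Checkpoint id and image URL are illustrative, not pinned by this series.
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # added in patch 1/4

init_image = load_image("https://example.com/input.png")

# No prompt needed: it defaults to "" for a single image, and to one ""
# per image (e.g. ["", ""]) when a list is passed.
result = pipe(init_image).images[0]
```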