[Stable UnCLIP] Finish Stable UnCLIP #2814

Merged · 4 commits · Mar 24, 2023
First changed file: the Stable UnCLIP text-to-image pipeline.
@@ -22,7 +22,7 @@
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

@@ -178,6 +178,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -581,6 +606,7 @@ def noise_image_embeddings(

         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)

         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -884,6 +910,10 @@ def __call__(
         # 14. Post-processing
         image = self.decode_latents(latents)

+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 15. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
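
A hedged usage sketch (not part of this diff) of how the new model-level offload on the text-to-image pipeline could be exercised; the checkpoint id below is a placeholder, not taken from this PR:

import torch
from diffusers import StableUnCLIPPipeline

# Load the Stable UnCLIP text-to-image pipeline; the model id is a placeholder.
pipe = StableUnCLIPPipeline.from_pretrained(
    "path-or-hub-id-of-a-stable-unclip-checkpoint", torch_dtype=torch.float16
)

# Keeps sub-models on CPU and moves each whole model to cuda:0 only for its
# forward pass; requires accelerate >= 0.17.0, per the check added above.
pipe.enable_model_cpu_offload()

# After generation, __call__ now calls final_offload_hook.offload(), so the last
# model is returned to CPU as well.
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")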
Second changed file: the Stable UnCLIP image-to-image (image variation) pipeline.
@@ -24,7 +24,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

@@ -180,6 +180,31 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -548,6 +573,7 @@ def noise_image_embeddings(

         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)

         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -571,8 +597,8 @@ def noise_image_embeddings(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 20,
@@ -597,8 +623,8 @@ def __call__(

         Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+                used or prompt is initialized to `""`.
             image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
                 the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -674,6 +700,9 @@ def __call__(
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor

+        if prompt is None and prompt_embeds is None:
+            prompt = len(image) * [""] if isinstance(image, list) else ""
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -777,6 +806,10 @@ def __call__(
         # 9. Post-processing
         image = self.decode_latents(latents)

+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
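
A hedged sketch (not part of this diff) of calling the image-to-image pipeline after these changes: `image` is now the first positional argument and, with the new default added above, an omitted `prompt` falls back to `""`. The checkpoint id and input path are placeholders:

import torch
from PIL import Image
from diffusers import StableUnCLIPImg2ImgPipeline

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "path-or-hub-id-of-a-stable-unclip-img2img-checkpoint", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # same accelerate >= 0.17.0 requirement as above

init_image = Image.open("input.png").convert("RGB")  # placeholder input image

# No prompt passed: the pipeline falls back to the empty prompt added in this PR,
# so conditioning comes from the CLIP image embedding of init_image.
variation = pipe(init_image).images[0]
variation.save("variation.png")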
Third changed file: the StableUnCLIPImageNormalizer module.
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional, Union
+
 import torch
 from torch import nn

@@ -37,6 +39,15 @@ def __init__(
         self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.std = nn.Parameter(torch.ones(1, embedding_dim))

+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
+        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
+        return self
+
     def scale(self, embeds):
         embeds = (embeds - self.mean) * 1.0 / self.std
         return embeds
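
A hedged, standalone sketch (not part of this diff) of the pattern the new `to` method implements: `mean` and `std` are re-wrapped as parameters on the requested device/dtype, which is what the `self.image_normalizer.to(image_embeds.device)` call added in `noise_image_embeddings` relies on to avoid a device mismatch when scaling. The class below is a simplified stand-in, not the diffusers class:

import torch
from torch import nn

class TinyNormalizer(nn.Module):
    # Simplified stand-in for StableUnCLIPImageNormalizer, for illustration only.
    def __init__(self, embedding_dim=4):
        super().__init__()
        self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.std = nn.Parameter(torch.ones(1, embedding_dim))

    def to(self, torch_device=None, torch_dtype=None):
        # Re-create the parameters on the requested device/dtype, mirroring the diff above.
        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
        return self

    def scale(self, embeds):
        return (embeds - self.mean) / self.std

normalizer = TinyNormalizer()
embeds = torch.randn(2, 4)
# Move the normalizer's parameters to the embeddings' device/dtype before scaling,
# as the added call in noise_image_embeddings does.
normalizer.to(torch_device=embeds.device, torch_dtype=embeds.dtype)
print(normalizer.scale(embeds).shape)  # torch.Size([2, 4])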