
VAE should have the ability to use bfloat16 if it is available to the user for LDMs that use a VAE, to save memory #4102

@AmericanPresidentJimmyCarter

Description

Is your feature request related to a problem? Please describe.
Currently there are many casts to float32 so that the VAE can run in float16 mode without overflowing and producing degenerate (zero/NaN) outputs. When the pipeline runs in float16 mode, or when dtype=torch.bfloat16 is requested, we should run the VAE in torch.bfloat16 if it is available in the environment. This would reduce VRAM usage by approximately half (relative to the float32 upcast) and in some cases increase performance.
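The reason bfloat16 sidesteps the problem is range, not precision: float16 tops out at 65504, while bfloat16 keeps float32's 8-bit exponent and therefore its dynamic range. A minimal illustration (plain PyTorch, no diffusers required):

import torch

x = torch.tensor([70000.0])   # exceeds the float16 maximum of 65504
print(x.to(torch.float16))    # tensor([inf], dtype=torch.float16)
print(x.to(torch.bfloat16))   # tensor([70144.], dtype=torch.bfloat16) -- finite, just coarser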

Describe the solution you'd like

# ...somewhere, in a pipeline...

# If the hardware supports bfloat16, move the VAE out of float16.
if self.vae.dtype == torch.float16 and torch.cuda.is_bf16_supported():
    self.vae.to(dtype=torch.bfloat16)

# The VAE overflows in float16, so it normally has to be upcast to float32.
# We don't need the upcasting-to-float32 dance if we have access to
# bfloat16, in which case we can simply recast the latents to match.
if self.vae.dtype == torch.bfloat16:
    latents = latents.to(dtype=torch.bfloat16)
elif self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
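Until the pipelines do this internally, roughly the same effect can be had from user code. A sketch, assuming the SDXL VAE, a CUDA device that reports bfloat16 support, and latents obtained separately (the random tensor below is only a stand-in for real pipeline latents):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda")

if torch.cuda.is_bf16_supported():
    # bfloat16 keeps float32's exponent range, so no float32 upcast is needed.
    vae.to(dtype=torch.bfloat16)

# Placeholder latents; in practice these come from a pipeline run with
# output_type="latent".
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.float16)

with torch.no_grad():
    image = vae.decode(latents.to(vae.dtype) / vae.config.scaling_factor).sample

Decoding in bfloat16 here avoids both the float32 upcast and the overflow that the VAE hits in float16.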
