**Is your feature request related to a problem? Please describe.**
Currently there are many casts to float32 so that the VAE does not run into numerical issues (overflow) when the pipeline runs in float16. When the pipeline dtype is torch.float16 or torch.bfloat16, we should instead run the VAE in torch.bfloat16 whenever the hardware supports it. Compared to the float32 upcast, this roughly halves the VAE's VRAM usage and can in some cases also improve performance.
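
For reference, a minimal, non-diffusers-specific sketch of the hardware check the proposal relies on and of where the roughly 2x saving over float32 comes from:

```python
import torch

# torch.cuda.is_bf16_supported() reports whether the current GPU can run
# bfloat16 kernels; the proposed snippet below relies on this check.
if torch.cuda.is_available():
    print("bf16 supported:", torch.cuda.is_bf16_supported())

# A bfloat16 element is 2 bytes vs. 4 bytes for float32, which is where the
# roughly 2x VRAM saving over the float32 upcast comes from.
print(torch.tensor([], dtype=torch.bfloat16).element_size())  # 2
print(torch.tensor([], dtype=torch.float32).element_size())   # 4
```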
**Describe the solution you'd like**
```python
# ...somewhere, in a pipeline...
if self.vae.dtype == torch.float16 and torch.cuda.is_bf16_supported():
    self.vae.to(dtype=torch.bfloat16)
```
```python
# Make sure the VAE is in float32 mode, as it overflows in float16.
# We don't need to do the upcasting and float32 dance if we have
# access to bfloat16, in which case we can just directly use the recast
# latents.
if self.vae.dtype == torch.bfloat16:
    latents = latents.to(dtype=torch.bfloat16)
elif self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
```
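
Outside the pipeline internals, the same idea can already be tried by hand by loading the VAE on its own and decoding in bfloat16. A minimal sketch; the checkpoint id and the latent shape are only illustrative assumptions:

```python
import torch
from diffusers import AutoencoderKL

# Load the SDXL VAE standalone and run it in bfloat16 instead of
# upcasting it to float32 at decode time.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
).to("cuda", dtype=torch.bfloat16)

# Dummy float16 latents standing in for the pipeline's denoised output
# (1024x1024 SDXL images correspond to 128x128 latents with 4 channels).
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.float16)

with torch.no_grad():
    # Recast the latents to the VAE dtype, mirroring the snippet above,
    # then decode without any float32 upcast.
    image = vae.decode(latents.to(vae.dtype) / vae.config.scaling_factor).sample
```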