Skip to content

Stable Cascade: error when generating images at arbitrary resolutions with bfloat16 #7355

@lawfordp2017

Description

@lawfordp2017

Describe the bug

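When I generate a 904x512 image with models loaded in bfloat16 (`torch_dtype=torch.bfloat16`), the decoder fails with an error stating that the float32 input and the bfloat16 bias must have the same type. According to the traceback below, `StableCascadeUNet._up_decode` passes `x.float()` to `torch.nn.functional.interpolate` and then feeds the float32 result directly into the next `SDCascadeResBlock`, whose bfloat16 convolution rejects it.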

See the associated pull request for the line that seems to be problematic.

Reproduction

import torch
from diffusers import (
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
    StableCascadeUNet,
)

# Both Stable Cascade stages (stage C -> prior, stage B -> decoder) are loaded
# from the single-file bf16 checkpoints and kept in bfloat16 throughout.
weight_dtype = torch.bfloat16
target_device = "cuda"


def load_unet(checkpoint_url):
    """Load a single-file Stable Cascade UNet checkpoint with ``weight_dtype``."""
    return StableCascadeUNet.from_single_file(checkpoint_url, torch_dtype=weight_dtype)


prior_unet = load_unet(
    "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
)
decoder_unet = load_unet(
    "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors"
)

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=weight_dtype
)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=weight_dtype
)

# Move both pipelines to the GPU, prior first and then decoder. The original
# script kept `enable_model_cpu_offload()` on each pipeline commented out as an
# alternative to this explicit move.
for pipeline in (prior, decoder):
    pipeline.to(target_device)

prompt = "an image of a shiba inu, donning a spacesuit and helmet, windy_tales"
negative_prompt = ""

# 904x512 is the resolution reported to fail; the original script kept
# height=1024 / width=1024 commented out as the alternative setting.
prior_kwargs = dict(
    prompt=prompt,
    height=512,
    width=904,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=20,
)
prior_output = prior(**prior_kwargs)

# The reported "Input type (float) and bias type (c10::BFloat16)" RuntimeError
# is raised inside this decoder call.
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10,
).images[0]
decoder_output.save("cascade-single-file.png")

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 18
      5 prior_output = prior(
      6     prompt=prompt,
      7     # height=1024,
   (...)
     14     num_inference_steps=20
     15 )
     17 #decoder.enable_model_cpu_offload()
---> 18 decoder_output = decoder(
     19     image_embeddings=prior_output.image_embeddings,
     20     prompt=prompt,
     21     negative_prompt=negative_prompt,
     22     guidance_scale=0.0,
     23     output_type="pil",
     24     num_inference_steps=10
     25 ).images[0]
     26 decoder_output.save("cascade-single-file.png")

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py:429, in StableCascadeDecoderPipeline.__call__(self, image_embeddings, prompt, num_inference_steps, guidance_scale, negative_prompt, prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled, num_images_per_prompt, generator, latents, output_type, return_dict, callback_on_step_end, callback_on_step_end_tensor_inputs)
    426 timestep_ratio = t.expand(latents.size(0)).to(dtype)
    428 # 7. Denoise latents
--> 429 predicted_latents = self.decoder(
    430     sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
    431     timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
    432     clip_text_pooled=prompt_embeds_pooled,
    433     effnet=effnet,
    434     return_dict=False,
    435 )[0]
    437 # 8. Check for classifier free guidance and apply it
    438 if self.do_classifier_free_guidance:

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/diffusers/models/unets/unet_stable_cascade.py:607, in StableCascadeUNet.forward(self, sample, timestep_ratio, clip_text_pooled, clip_text, clip_img, effnet, pixels, sca, crp, return_dict)
    603     x = x + nn.functional.interpolate(
    604         self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
    605     )
    606 level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
--> 607 x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
    608 sample = self.clf(x)
    610 if not return_dict:

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/diffusers/models/unets/unet_stable_cascade.py:555, in StableCascadeUNet._up_decode(self, level_outputs, r_embed, clip)
    551                                 x = torch.nn.functional.interpolate(
    552                                     x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
    553                                 )
    554 #                                x = x.type(orig_type)
--> 555                             x = block(x, skip)
    556                         elif isinstance(block, SDCascadeAttnBlock):
    557                             x = block(x, clip)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/diffusers/models/unets/unet_stable_cascade.py:74, in SDCascadeResBlock.forward(self, x, x_skip)
     72 def forward(self, x, x_skip=None):
     73     x_res = x
---> 74     x = self.norm(self.depthwise(x))
     75     if x_skip is not None:
     76         x = torch.cat([x, x_skip], dim=1)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File ~/anaconda3/envs/sd_bleeding/lib/python3.10/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

System Info

  • diffusers version: 0.27.0.dev0
  • Platform: Linux-6.8.0-11-generic-x86_64-with-glibc2.39
  • Python version: 3.10.8
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Huggingface_hub version: 0.20.2
  • Transformers version: 4.36.2
  • Accelerate version: 0.28.0
  • xFormers version: 0.0.24
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

Metadata

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions