
[MPS] SDXL pipeline fails inference in fp16 mode #7426


Description

@bghira

Describe the bug

When running a pipeline under 🤗 Accelerate, my M3 system crashes with a tensor dtype mismatch during validation inference.

This issue does not occur without 🤗 Accelerate.

Reproduction

Unfortunately, the failure occurs deep inside a training script, which makes it exceedingly difficult and time-consuming to put together a small standalone reproducer.
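
For context, the failing configuration is roughly equivalent to running SDXL in fp16 on MPS with a DDIM scheduler. A standalone sketch would look something like the following (model id and prompt are placeholders, and since the real failure happens inside the Accelerate training loop, this is only an approximation, not a verified reproducer):

    # Rough sketch only: approximates the validation path of the training script.
    # The model id and prompt are placeholders, not the values used in the real run.
    import torch
    from diffusers import StableDiffusionXLPipeline, DDIMScheduler

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder
        torch_dtype=torch.float16,
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to("mps")

    # The dtype flip is observed on the second denoising step.
    image = pipe(
        "a photo of a cat",  # placeholder
        num_inference_steps=30,
        guidance_scale=7.5,
        guidance_rescale=0.7,
    ).images[0]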

I added print statements to the pipeline so that I could determine where the broadcast compatibility error originates, and discovered this line:

                print("compute the previous noisy sample x_t -> x_t-1")
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

which seems to bump the dtype from float16 to float32 on the second step; the first step is correct and fine.
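
For what it's worth, a quick way to check whether the scheduler's own buffers are the source of the upcast is to print their dtypes right before the step (sketch only; the variable names are the ones in the pipeline's denoising loop, and the alphas_cumprod attribute name is assumed from DDIMScheduler):

    # Sketch: if alphas_cumprod is float32, arithmetic against it inside
    # scheduler.step() is a plausible place for the latents to get promoted.
    print("alphas_cumprod dtype:", self.scheduler.alphas_cumprod.dtype)
    print("noise_pred dtype:", noise_pred.dtype)
    print("latents dtype:", latents.dtype)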

Just on a whim, I modified the code to the following:

                print("compute the previous noisy sample x_t -> x_t-1")
                old_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != old_dtype:
                    latents = latents.to(old_dtype)

which then allows validations to run to completion.

I'm sure there's a deeper issue and this is not a valid solution, so I have not opened a PR with it.

It seems like the DDIM scheduler simply should not change the dtype.
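
If anyone needs to work around this without editing the pipeline source, one option (a sketch of the same idea, not a proposed fix) is to wrap the scheduler's step() so its output is cast back to the input sample's dtype:

    # Sketch of a pipeline-agnostic workaround: cast step() output back to the
    # dtype of the incoming sample. This only hides the symptom; the underlying
    # upcast still happens inside the scheduler. ``pipe`` is a hypothetical
    # StableDiffusionXLPipeline instance.
    import functools

    def keep_step_dtype(scheduler):
        original_step = scheduler.step

        @functools.wraps(original_step)
        def step(model_output, timestep, sample, *args, **kwargs):
            out = original_step(model_output, timestep, sample, *args, **kwargs)
            if isinstance(out, tuple):  # the pipeline calls step(..., return_dict=False)
                return (out[0].to(sample.dtype),) + out[1:]
            out.prev_sample = out.prev_sample.to(sample.dtype)
            return out

        scheduler.step = step
        return scheduler

    pipe.scheduler = keep_step_dtype(pipe.scheduler)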

Logs

Before the workaround:


2024-03-21 10:14:22,290 [DEBUG] (validation) Generating validation image: 
 Device allocations:
 -> unet on mps:0
 -> text_encoder on None
 -> vae on mps:0
 -> current_validation_prompt_embeds on mps:0
 -> current_validation_pooled_embeds on mps:0
 -> validation_negative_prompt_embeds on mps:0
 -> validation_negative_pooled_embeds on mps:0
2024-03-21 10:14:22,361 [DEBUG] (validation) Generating validation image: 
 Weight dtypes:
 -> unet: torch.float16
 -> text_encoder: None
 -> vae: torch.float16
 -> current_validation_prompt_embeds: torch.float16
 -> current_validation_pooled_embeds: torch.float16
 -> validation_negative_prompt_embeds: torch.float16
 -> validation_negative_pooled_embeds: torch.float16
2024-03-21 10:14:22,361 [DEBUG] (validation) Generating validation image: 
 -> Number of images: 1
 -> Number of inference steps: 30
 -> Guidance scale: 7.5
 -> Guidance rescale: 0.7
 -> Resolution: 1024.0
 -> Extra validation kwargs: {'generator': <torch._C.Generator object at 0x6cb88de10>}
1. Check inputs. Raise error if not correct
Prompt embeds shape: torch.Size([1, 77, 2048])
Negative prompt embeds shape: torch.Size([1, 77, 2048])
2. Define call parameters
3. Encode input prompt
4. Prepare timesteps
5. Prepare latent variables
6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
7. Prepare added time ids & embeddings
8. Denoising loop
8.1 Apply denoising_end
9. Optionally get Guidance Scale Embedding
expand the latents if we are doing classifier free guidance
predict the noise residual
 -> dtype: torch.float16
 -> dtype (latents): torch.float16
 -> dtype (prompt_embeds): torch.float16
 -> dtype (add_text_embeds): torch.float16
perform guidance
Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
compute the previous noisy sample x_t -> x_t-1
call the callback, if provided
expand the latents if we are doing classifier free guidance
predict the noise residual
 -> dtype: torch.float32
 -> dtype (latents): torch.float32
 -> dtype (prompt_embeds): torch.float16
 -> dtype (add_text_embeds): torch.float16
loc("mps_add"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<2x1280xf32>' and 'tensor<1280xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).

After the workaround:

 Weight dtypes:
 -> unet: torch.float16
 -> text_encoder: None
 -> vae: torch.float16
 -> current_validation_prompt_embeds: torch.float16
 -> current_validation_pooled_embeds: torch.float16
 -> validation_negative_prompt_embeds: torch.float16
 -> validation_negative_pooled_embeds: torch.float16
2024-03-21 10:12:17,921 [DEBUG] (validation) Generating validation image:
 -> Number of images: 1
 -> Number of inference steps: 30
 -> Guidance scale: 7.5
 -> Guidance rescale: 0.7
 -> Resolution: 1024.0
 -> Extra validation kwargs: {'generator': <torch._C.Generator object at 0x35ae4be50>}
1. Check inputs. Raise error if not correct
Prompt embeds shape: torch.Size([1, 77, 2048])
Negative prompt embeds shape: torch.Size([1, 77, 2048])
2. Define call parameters
3. Encode input prompt
4. Prepare timesteps
5. Prepare latent variables
6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
7. Prepare added time ids & embeddings
8. Denoising loop
8.1 Apply denoising_end
9. Optionally get Guidance Scale Embedding
expand the latents if we are doing classifier free guidance
predict the noise residual
 -> dtype: torch.float16
 -> dtype (latents): torch.float16
 -> dtype (prompt_embeds): torch.float16
 -> dtype (add_text_embeds): torch.float16
perform guidance
Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
compute the previous noisy sample x_t -> x_t-1
call the callback, if provided
expand the latents if we are doing classifier free guidance
predict the noise residual
 -> dtype: torch.float16
 -> dtype (latents): torch.float16
 -> dtype (prompt_embeds): torch.float16
 -> dtype (add_text_embeds): torch.float16

System Info

  • diffusers version: 0.26.3
  • Platform: macOS-14.4-arm64-arm-64bit
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.2.1 (False)
  • Huggingface_hub version: 0.21.4
  • Transformers version: 4.40.0.dev0
  • Accelerate version: 0.26.1
  • xFormers version: not installed
  • Using GPU in script?: M3 Max
  • Using distributed or parallel set-up in script?: False

Who can help?

@pcuenca


    Labels

    bug (Something isn't working)
