Describe the bug
When running a pipeline under 🤗 Accelerate on my M3 system, validation inference crashes with a tensor dtype mismatch.
This issue does not occur without 🤗 Accelerate.
Reproduction
Unfortunately, the reproducer is embedded in a training script, which makes it exceedingly difficult and time-consuming to put together a minimal reproduction.
I added print statements to the pipeline to locate where the broadcast-compatibility error originates, and discovered this line:

```python
print("compute the previous noisy sample x_t -> x_t-1")
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
```

It seems to bump the dtype from float16 to float32 on the second step; the first step is correct and fine.
Just on a whim, I modified the code to the following:

```python
print("compute the previous noisy sample x_t -> x_t-1")
old_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != old_dtype:
    latents = latents.to(old_dtype)
```

This then allows validation to complete and finish.
I'm sure there's a deeper issue and this is not a valid solution, so I have not opened a PR with it. It seems like the DDIM scheduler simply should not change the dtype of the latents.
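For reference, the workaround pattern above can be factored into a small dtype-preserving wrapper around `scheduler.step`. This is only a sketch of the pattern, not the pipeline's actual code: `FakeTensor` and `FakeScheduler` are hypothetical stand-ins (for a torch tensor and a diffusers scheduler) so the example is self-contained, and `FakeScheduler.step` deliberately simulates the suspected bug by returning a float32 result.

```python
# Sketch of a dtype-preserving wrapper around scheduler.step().
# FakeTensor / FakeScheduler are hypothetical stand-ins used only to
# illustrate the pattern; they are not part of diffusers.

class FakeTensor:
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # Mimics torch.Tensor.to(dtype): returns a tensor with the new dtype.
        return FakeTensor(dtype)


class FakeScheduler:
    def step(self, noise_pred, t, latents, return_dict=False):
        # Simulates the reported bug: internal float32 math upcasts the output.
        return (FakeTensor("float32"),)


def step_preserving_dtype(scheduler, noise_pred, t, latents, **kwargs):
    """Call scheduler.step() but force the output dtype to match the input."""
    old_dtype = latents.dtype
    out = scheduler.step(noise_pred, t, latents, return_dict=False, **kwargs)[0]
    if out.dtype != old_dtype:
        # Cast back so downstream MPS kernels see a consistent dtype.
        out = out.to(old_dtype)
    return out
```

With real torch tensors, a float16 input would come back as float16 regardless of what the scheduler does internally, which is exactly the behavior the inline workaround restores.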
Logs
Before resolution:
2024-03-21 10:14:22,290 [DEBUG] (validation) Generating validation image:
Device allocations:
-> unet on mps:0
-> text_encoder on None
-> vae on mps:0
-> current_validation_prompt_embeds on mps:0
-> current_validation_pooled_embeds on mps:0
-> validation_negative_prompt_embeds on mps:0
-> validation_negative_pooled_embeds on mps:0
2024-03-21 10:14:22,361 [DEBUG] (validation) Generating validation image:
Weight dtypes:
-> unet: torch.float16
-> text_encoder: None
-> vae: torch.float16
-> current_validation_prompt_embeds: torch.float16
-> current_validation_pooled_embeds: torch.float16
-> validation_negative_prompt_embeds: torch.float16
-> validation_negative_pooled_embeds: torch.float16
2024-03-21 10:14:22,361 [DEBUG] (validation) Generating validation image:
-> Number of images: 1
-> Number of inference steps: 30
-> Guidance scale: 7.5
-> Guidance rescale: 0.7
-> Resolution: 1024.0
-> Extra validation kwargs: {'generator': <torch._C.Generator object at 0x6cb88de10>}
1. Check inputs. Raise error if not correct
Prompt embeds shape: torch.Size([1, 77, 2048])
Negative prompt embeds shape: torch.Size([1, 77, 2048])
2. Define call parameters
3. Encode input prompt
4. Prepare timesteps
5. Prepare latent variables
6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
7. Prepare added time ids & embeddings
8. Denoising loop
8.1 Apply denoising_end
9. Optionally get Guidance Scale Embedding
expand the latents if we are doing classifier free guidance
predict the noise residual
-> dtype: torch.float16
-> dtype (latents): torch.float16
-> dtype (prompt_embeds): torch.float16
-> dtype (add_text_embeds): torch.float16
perform guidance
Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
compute the previous noisy sample x_t -> x_t-1
call the callback, if provided
expand the latents if we are doing classifier free guidance
predict the noise residual
-> dtype: torch.float32
-> dtype (latents): torch.float32
-> dtype (prompt_embeds): torch.float16
-> dtype (add_text_embeds): torch.float16
loc("mps_add"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<2x1280xf32>' and 'tensor<1280xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).

After resolution:
Weight dtypes:
-> unet: torch.float16
-> text_encoder: None
-> vae: torch.float16
-> current_validation_prompt_embeds: torch.float16
-> current_validation_pooled_embeds: torch.float16
-> validation_negative_prompt_embeds: torch.float16
-> validation_negative_pooled_embeds: torch.float16
2024-03-21 10:12:17,921 [DEBUG] (validation) Generating validation image:
-> Number of images: 1
-> Number of inference steps: 30
-> Guidance scale: 7.5
-> Guidance rescale: 0.7
-> Resolution: 1024.0
-> Extra validation kwargs: {'generator': <torch._C.Generator object at 0x35ae4be50>}
1. Check inputs. Raise error if not correct
Prompt embeds shape: torch.Size([1, 77, 2048])
Negative prompt embeds shape: torch.Size([1, 77, 2048])
2. Define call parameters
3. Encode input prompt
4. Prepare timesteps
5. Prepare latent variables
6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
7. Prepare added time ids & embeddings
8. Denoising loop
8.1 Apply denoising_end
9. Optionally get Guidance Scale Embedding
expand the latents if we are doing classifier free guidance
predict the noise residual
-> dtype: torch.float16
-> dtype (latents): torch.float16
-> dtype (prompt_embeds): torch.float16
-> dtype (add_text_embeds): torch.float16
perform guidance
Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
compute the previous noisy sample x_t -> x_t-1
call the callback, if provided
expand the latents if we are doing classifier free guidance
predict the noise residual
-> dtype: torch.float16
-> dtype (latents): torch.float16
-> dtype (prompt_embeds): torch.float16
-> dtype (add_text_embeds): torch.float16
System Info
- diffusers version: 0.26.3
- Platform: macOS-14.4-arm64-arm-64bit
- Python version: 3.10.13
- PyTorch version (GPU?): 2.2.1 (False)
- Huggingface_hub version: 0.21.4
- Transformers version: 4.40.0.dev0
- Accelerate version: 0.26.1
- xFormers version: not installed
- Using GPU in script?: M3 Max
- Using distributed or parallel set-up in script?: False