Conversation

@jsmidt
Contributor

jsmidt commented Aug 9, 2024

What does this PR do?

dtype=torch.float64 is overkill, and float64 is not defined for certain devices such as Apple Silicon mps. This change enables the flux pipeline to be run on certain devices such as Apple Silicon mps without negative consequences.
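As a rough, torch-free sketch of the computation the changed dtype feeds into (the names and shapes here are illustrative, not the exact diffusers code): the rotary embedding builds per-position angular frequencies, and the precision of that intermediate is what this PR relaxes from float64 to float32.

```python
import math

def rope_angles(pos: float, dim: int, theta: float = 10000.0) -> list:
    """Angles for a rotary position embedding at one position.

    In transformer_flux.py the analogous tensor was created with
    dtype=torch.float64; this PR switches it to float32 so the op is
    defined on backends (such as MPS) that lack float64 kernels.
    """
    return [pos / (theta ** (2 * i / dim)) for i in range(dim // 2)]

# The embedding itself is built from cos/sin of these angles.
angles = rope_angles(pos=3.0, dim=8)
cos_sin = [(math.cos(a), math.sin(a)) for a in angles]
```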

jsmidt added 2 commits August 8, 2024 19:15:
- dtype=torch.float64 is overkill, and float64 is not defined for certain devices such as Apple Silicon mps.
- Update transformer_flux.py. Change float64 to float32.
@bghira
Contributor

bghira commented Aug 9, 2024

just a note that macos 14 and pytorch 2.4 or greater still can't do it. but i think macos 15 can, or, pytorch 2.3.1 with macos 14. but then training uses a lot more vram.

edit: no, macos 15 still broken - don't upgrade to fix it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hvaara
Contributor

hvaara commented Aug 13, 2024

This fix should make the model runnable on MPS. Tested with https://gist.github.com/hvaara/bc8754b2aab6ef07a95c82c5e436f6d3. Running macOS 14.6. transformers and diffusers from the main branch.

@bghira Does it work for you with the patch from this PR and the code from my gist? It needs ~45 GB of VRAM. If not, what error are you seeing?

@bghira
Contributor

bghira commented Aug 13, 2024

oh roy 🙉 we meet again. i see no error, it's just that the image is all noise

@hvaara
Contributor

hvaara commented Aug 14, 2024

Haha! Indeed we do 😂

Are you using latest diffusers and transformers? Are you sure your weights are not corrupted? With the code change by OP and the script in my gist I get great images using MPS as the accelerator.

I actually came here to contribute the exact same change as OP 😅

@bghira
Contributor

bghira commented Aug 14, 2024

all day every day on git branches.. latest and hot off the wire. pytorch nightly, latest diffusers/transformers, macos 15 beta 4.

@hvaara
Contributor

hvaara commented Aug 14, 2024

The issue @bghira experienced has been identified as a bug in PyTorch. I will open a bug report and propose a fix upstream.

@hvaara
Contributor

hvaara commented Aug 15, 2024

Follow pytorch/pytorch#133520 for updates on the noisy output image issue.

@sayakpaul
Member

Does this PR reliably solve the problem?

Cc: @DN6 @pcuenca

@hvaara
Contributor

hvaara commented Aug 16, 2024

Yes. The only thing I would consider is the precision reduction. This has been solved in the past by predicating on the device and only reducing precision for MPS.

Prior art: #1169 #6365 #942
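The device-predicated approach from those earlier fixes can be sketched as a small helper (the function name is hypothetical; the actual diffusers code differs):

```python
def rope_freqs_dtype(device_type: str) -> str:
    """Choose the precision for rotary-embedding frequencies.

    float64 kernels are not implemented for the MPS backend, so reduce
    precision only there and keep float64 everywhere else, mirroring
    the device-predicated pattern used in the prior-art PRs.
    """
    return "float32" if device_type == "mps" else "float64"
```

In real code the check would typically be against the position tensor's `device.type`.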

@sayakpaul
Member

Thanks!

This has been solved in the past by predicating on the device and only reducing precision for MPS.

I would advocate for this myself, and perhaps also logging a warning that we're reducing precision here and that results may be unexpected, with a reference to this PR. WDYT?
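That warning could look roughly like this (hypothetical function and message wording; a sketch, not the final implementation):

```python
import logging

logger = logging.getLogger(__name__)

def freqs_dtype_with_warning(device_type: str) -> str:
    # Downcast only on MPS, which lacks float64 kernels, and warn the
    # user that reduced precision may change the output slightly.
    if device_type == "mps":
        logger.warning(
            "float64 is not supported on the MPS backend; computing rotary "
            "embeddings in float32 instead. Results may differ slightly."
        )
        return "float32"
    return "float64"
```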

@hvaara
Contributor

hvaara commented Aug 16, 2024

Did some testing. I don't know how much of an impact it has, but overall I think the images generated with float64 look the best.

A cat

Prompt: A cat holding a sign that says hello world

float16: [image: output_image_mps_1_A cat holding a sign that says hello world_newtorch_16]
float32: [image: output_image_mps_1_A cat holding a sign that says hello world_newtorch_32]
float64: [image: output_image_mps_1_A cat holding a sign that says hello world_newtorch_64]

A landscape

Prompt: Highly detailed fantasy landscape at golden hour in an ancient forest. Towering trees with glowing runes are illuminated by warm light, casting shadows on a mossy floor dotted with glowing flowers.

A clear stream winds through the foreground, reflecting the sky's hues and surrounded by glowing mushrooms. Fairy-like creatures with translucent wings flutter above, leaving shimmering trails.

In the background (CLIP cut me off)

float16: [image: output_image_mps_1_longprompt_newtorch_16]
float32: [image: output_image_mps_1_longprompt_newtorch_32]
float64: [image: output_image_mps_1_longprompt_newtorch_64]

@DN6
Collaborator

DN6 commented Aug 16, 2024

Yeah agree with @hvaara. @jsmidt could we just add a check for MPS device and then downcast to FP32?

@yiyixuxu
Collaborator

yiyixuxu commented Aug 17, 2024

cc @asomoza here

Can you take a look (or possibly run more tests) to see if there is any difference? My eye couldn't spot any quality difference between the float64 and float32 outputs.

For context: I'm refactoring flux to use get_1d_rotary_pos_embed in #9074, which uses float32. If float64 indeed generates better results, we might want to apply it to all models that use rotary embeddings (and of course downcast for MPS).
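The "compute in float64, downcast for MPS" idea can be sketched without torch by using struct to emulate float32 rounding (illustrative only; the real code would call .to() on a tensor):

```python
import struct

def as_float32(x: float) -> float:
    """Round a Python float (IEEE float64) to float32 precision."""
    return struct.unpack("f", struct.pack("f", x))[0]

def rope_angles(pos: float, dim: int, theta: float = 10000.0,
                downcast: bool = False) -> list:
    # Compute the angles at full float64 precision first; optionally
    # downcast each one, standing in for a .to(torch.float32) on MPS.
    angles = [pos / (theta ** (2 * i / dim)) for i in range(dim // 2)]
    return [as_float32(a) for a in angles] if downcast else angles
```

The image comparison above suggests any visible difference between the two paths is subtle.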

@bghira
Contributor

bghira commented Aug 17, 2024

some models do better with fp16 rope embeds. a transformer config option maybe?

@asomoza
Member

asomoza commented Aug 19, 2024

can you take a look (or possibly run more test) to see if there is any difference

@yiyixuxu, I did some tests with flux dev and I also don't see a really big difference, but if I have to choose one, I think float64 is better when I look closely at some tiny details.

yiyixuxu pushed a commit that referenced this pull request Aug 19, 2024
@yiyixuxu yiyixuxu mentioned this pull request Aug 19, 2024
2 tasks
yiyixuxu added a commit that referenced this pull request Aug 21, 2024
* refactor rotary embeds

* adding jsmidt as co-author of this PR for #9133

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Joseph Smidt <[email protected]>
@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@beaugunderson

did anything remove the need for this on mps?

@hvaara
Contributor

hvaara commented Sep 14, 2024

Yes, #9074 solved it. This PR can be closed.

@sayakpaul sayakpaul closed this Sep 14, 2024
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* refactor rotary embeds

* adding jsmidt as co-author of this PR for #9133

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Joseph Smidt <[email protected]>