Skip inconsistent tests for MPS in test_unet_2d_blocks #2717

Closed
wants to merge 1 commit into from

apivovarov
Contributor

All recent PR validations failed because two tests in test_unet_2d_blocks.py give inconsistent results on Mac (MPS device):

FAILED tests/test_unet_2d_blocks.py::SimpleCrossAttnUpBlock2DTests::test_output - AssertionError: Max diff is absolute 1.2456624507904053. Diff tensor is tensor([0.1679, 0.4054, 0.3008, 0.2366, 1.2457, 1.0652, 0.9767, 0.6834, 0.3251],
       device='mps:0').
FAILED tests/test_unet_2d_blocks.py::AttnUpDecoderBlock2DTests::test_output - AssertionError: Max diff is absolute 1.7274205684661865. Diff tensor is tensor([1.0407, 0.7878, 0.0026, 1.7274, 0.4588, 0.6572, 0.2625, 0.2807, 0.1914],
       device='mps:0').

I ran these tests on a Mac.

SimpleCrossAttnUpBlock2DTests randomly gives the following two outputs:

[ 0.5771,  0.5599,  0.1788,  0.2298, -0.1610, -0.6679, -0.3543, -0.2655, 0.1562]
[ 0.1935,  0.2100,  0.1523,  0.3923, -0.0860, -0.0189, -0.3123, -0.2552, 0.1499]

AttnUpDecoderBlock2DTests randomly gives the following two outputs:

[-0.0069,  0.1846, -0.1053,  0.0273,  0.0236,  0.0067,  0.0031,  0.1628, 0.0043]
[0.0422, 0.0423, 0.0418, 0.0422, 0.0422, 0.0420, 0.0423, 0.0424, 0.0419]

Two other tests in this file, SimpleCrossAttnDownBlock2DTests and AttnUpBlock2DTests, are already skipped when the device is mps.

We can skip SimpleCrossAttnUpBlock2DTests and AttnUpDecoderBlock2DTests on the mps device too.
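For illustration, a device-conditional skip of this kind can be sketched with the standard library's `unittest.skipIf`. This is only a sketch under stated assumptions, not the diff in this PR: the `torch_device` variable is a hard-coded stand-in for whatever device string the test suite resolves, the skip reason text is illustrative, and the test body is a placeholder.

```python
import unittest

# Assumption: stand-in for the device string the test suite resolves at
# import time (for example "cuda", "mps", or "cpu"). Hard-coded here so the
# sketch runs anywhere.
torch_device = "mps"


# Skip the whole test class when the device is mps.
@unittest.skipIf(torch_device == "mps", "MPS result is not consistent")
class SimpleCrossAttnUpBlock2DTests(unittest.TestCase):
    def test_output(self):
        # Placeholder body: the real test compares a slice of the block's
        # output against expected values.
        self.assertTrue(True)


def run_suite():
    """Run the test class and return the collected unittest.TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(
        SimpleCrossAttnUpBlock2DTests
    )
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Calling `run_suite()` with `torch_device` set to `"mps"` reports the test as skipped rather than failed, which is the behavior proposed here for the two flaky tests.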

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor

cc @pcuenca can you check here?

@pcuenca
Member

pcuenca commented Mar 21, 2023

Hi @apivovarov! Actually those tests are now consistent and pass when using PyTorch 2, which is in use by our CI system. I'm really sorry I saw this notification after we had already merged #2766, otherwise I would have added you as an author :(

Going forward, I'd recommend upgrading to PyTorch 2 if you can, as it fixes a few known issues with the mps device. In fact, some of the nasty "warmup" passes that were previously required to obtain consistent results are probably not needed anymore; I'll test it today.
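As a rough sketch of how a test suite might apply an mps workaround (a skip or a warmup pass) only on older PyTorch versions, the helper below gates on the major version. The function name `needs_mps_workaround` is hypothetical, and the rule that every release below 2 is affected is an assumption drawn from this comment, not something verified here.

```python
import re


def needs_mps_workaround(torch_version: str, device: str) -> bool:
    """Return True when running on mps with a PyTorch major version below 2.

    Hypothetical helper: the "major version < 2" rule is an assumption.
    """
    match = re.match(r"(\d+)\.", torch_version)
    if match is None:
        # Unknown version format: be conservative and keep the workaround
        # whenever the device is mps.
        return device == "mps"
    return device == "mps" and int(match.group(1)) < 2
```

In a real suite the version string would come from something like `torch.__version__`, and the result could feed a `unittest.skipIf` condition or decide whether to run a warmup pass.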

@pcuenca
Member

pcuenca commented Mar 21, 2023

See #2771 for an update on the warmup passes in PyTorch 2.

@apivovarov
Contributor Author

Got it. Thank you!

@apivovarov apivovarov closed this Mar 21, 2023