Enable FP8 concat_xpu and where_xpu #2152
Merged
Changes from all commits (14 commits):
747995c  Update Shape.cpp (yucai-intel)
3f9538a  Update TensorCompareKernels.cpp (yucai-intel)
9546e78  Update test_cat.py (yucai-intel)
1653fe7  Merge branch 'main' into yucai/fp8 (yucai-intel)
7f89018  Update test_cat.py (yucai-intel)
95ac047  format (yucai-intel)
2e2abe3  format (yucai-intel)
226ccb3  format (yucai-intel)
652ae58  format (yucai-intel)
a115cb2  format (yucai-intel)
6ae6661  Create test_where.py (yucai-intel)
fbaf98f  format (yucai-intel)
5ebda52  format (yucai-intel)
85eaed6  Update test/regressions/test_where.py (CuiYifeng)
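For context, the ops named in the PR title can be exercised directly once the XPU kernels accept float8 inputs. The snippet below is a minimal sketch and is not part of the PR; it assumes a PyTorch build with XPU support (so that `torch.xpu.is_available()` returns True) and uses `float8_e4m3fn` as a representative FP8 dtype.

```python
import torch

# Illustrative only: requires a PyTorch XPU build with the FP8 paths this PR enables.
if torch.xpu.is_available():
    a = torch.randn(2, 3).to(torch.float8_e4m3fn).xpu()
    b = torch.randn(2, 3).to(torch.float8_e4m3fn).xpu()
    mask = (torch.rand(2, 3) > 0.5).xpu()

    cat_out = torch.cat([a, b], dim=0)       # concat path, result stays float8_e4m3fn
    where_out = torch.where(mask, a, b)      # where path, result stays float8_e4m3fn

    print(cat_out.shape, cat_out.dtype)      # torch.Size([4, 3]) torch.float8_e4m3fn
    print(where_out.shape, where_out.dtype)  # torch.Size([2, 3]) torch.float8_e4m3fn
```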
test/regressions/test_where.py (new file, +92 lines):

# Owner(s): ["module: intel"]
import torch
from torch.testing._internal.common_utils import TestCase


class TestTorchWhereMethod(TestCase):
    # Define float8 dtypes
    FLOAT8_DTYPES = (
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e8m0fnu,
    )

    # Define the set of all dtypes to be tested
    TEST_DTYPES = (
        torch.float32,
        torch.float64,
        torch.half,
        torch.bfloat16,
    ) + FLOAT8_DTYPES

    def _test_where_fn(self, dtype):
        """Core function to test torch.where(condition, x, y) correctness."""

        # Input tensors (x and y)
        x = torch.tensor([[10.0, 20.0], [30.0, 40.0]], dtype=dtype)
        y = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]], dtype=dtype)
        # Condition must be bool
        condition = torch.tensor([[True, False], [False, True]], dtype=torch.bool)

        # --- 1. CPU reference calculation and tolerance setting ---
        if dtype in self.FLOAT8_DTYPES:
            # FP8: use float32 as reference type for comparison
            x_ref = x.cpu().to(torch.float32)
            y_ref = y.cpu().to(torch.float32)
            rtol = 1e-2
            atol = 1e-2
        else:
            # Non-FP8: use original dtype as reference type
            x_ref = x.cpu()
            y_ref = y.cpu()
            rtol = 1e-5
            atol = 1e-5

        condition_ref = condition.cpu()
        res_ref = torch.where(condition_ref, x_ref, y_ref)

        # --- 2. XPU operation (default) ---
        x_xpu = x.xpu()
        y_xpu = y.xpu()
        condition_xpu = condition.xpu()

        res_xpu = torch.where(condition_xpu, x_xpu, y_xpu)

        # Prepare XPU result for comparison (must match res_ref dtype)
        if dtype in self.FLOAT8_DTYPES:
            # FP8: convert XPU result to float32
            res_xpu_to_compare = res_xpu.cpu().to(torch.float32)
        else:
            # Non-FP8: pull to CPU, keeping original dtype
            res_xpu_to_compare = res_xpu.cpu()

        # Compare: res_ref vs res_xpu_to_compare
        self.assertEqual(res_ref, res_xpu_to_compare, rtol=rtol, atol=atol)

        # --- 3. Test the version with the out= argument ---

        # Create output tensor on XPU
        res_xpu_out = torch.empty_like(res_xpu, dtype=dtype).xpu()
        torch.where(condition_xpu, x_xpu, y_xpu, out=res_xpu_out)

        # Prepare XPU 'out' result for comparison
        if dtype in self.FLOAT8_DTYPES:
            # FP8: convert XPU result to float32
            res_xpu_out_to_compare = res_xpu_out.cpu().to(torch.float32)
        else:
            # Non-FP8: pull to CPU, keeping original dtype
            res_xpu_out_to_compare = res_xpu_out.cpu()

        # Compare: res_ref vs res_xpu_out_to_compare
        self.assertEqual(res_ref, res_xpu_out_to_compare, rtol=rtol, atol=atol)

    def test_where(self):
        """Test torch.where() correctness across all supported dtypes, including float8."""
        for dtype in self.TEST_DTYPES:
            # Use string conversion for better subTest reporting
            dtype_name = str(dtype).split(".")[-1]
            with self.subTest(dtype=dtype_name):
                self._test_where_fn(dtype)
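A brief note on the looser rtol/atol of 1e-2 used for the FP8 dtypes above: float8 formats quantize values coarsely, so the test compares results in float32 after round-tripping through the FP8 type. The snippet below is an independent CPU-only illustration of that rounding, using `float8_e4m3fn`; the specific sample values are illustrative and not part of the test.

```python
import torch

x = torch.tensor([0.3, 1.7, 123.4])
x8 = x.to(torch.float8_e4m3fn).to(torch.float32)  # round-trip through FP8

print(x8)                         # values snapped to the nearest representable FP8 value
print((x8 - x).abs() / x.abs())   # relative error on the order of a few percent
```

In the test itself the same conversion is applied on both sides: the inputs are constructed in the FP8 dtype before the CPU reference is computed, so the CPU and XPU results are compared between identically quantized values.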