Commit 20baf08

Temporary Fix for FP16 -> FP8 conversion failure on -0.0 (#2387)
Resolves #2219. This PR temporarily works around an issue where FP16's -0.0 is erroneously converted to NaN by certain fusion passes (fp16 -> fp32 -> fp8). To prevent the problematic fusion from occurring, we avoid the **sycl::half** data type in the intermediate conversion steps. --------- Co-authored-by: Cui, Yifeng <[email protected]>
1 parent 0c85351 commit 20baf08
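
For context, a minimal repro sketch of the failure mode this commit works around: a hedged example assuming a PyTorch build with XPU support and an available XPU device; the float8_e4m3fn dtype is chosen for illustration only.

    import torch

    # Convert fp16 -0.0 to fp8 on CPU and on XPU; the two results should agree.
    src = torch.full((4,), -0.0, dtype=torch.float16)
    dst_cpu = src.to(torch.float8_e4m3fn)
    dst_xpu = src.to("xpu").to(torch.float8_e4m3fn)
    # Before this fix, the fp16 -> fp32 -> fp8 fusion on XPU turned -0.0 into NaN.
    print(dst_cpu.float())        # zeros
    print(dst_xpu.cpu().float())  # zeros after the fix; NaN before it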

2 files changed: +52 −0 lines changed

src/ATen/native/xpu/sycl/CopyKernel.cpp

Lines changed: 11 additions & 0 deletions
@@ -25,6 +25,17 @@ struct CastScalarFunc {
   }
 };
 
+// TODO: Avoid using sycl::half to prevent the fp16->fp32->fp8 fusion
+// from incorrectly converting -0.0 to NaN. This temporary fix should
+// be removed once the compiler/driver error is resolved.
+template <typename Float8DataType>
+struct CastScalarFunc<Half, Float8DataType> {
+  Float8DataType operator()(Half src_val) const {
+    Half val = src_val == Half(-0.0) ? Half(0.0) : src_val;
+    return Float8DataType(val);
+  }
+};
+
 void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
   ScalarType dtype = iter.dtype(0);
   ScalarType other_dtype = iter.dtype(1);
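
A note on the comparison in the added specialization: under IEEE-754, -0.0 compares equal to +0.0, so src_val == Half(-0.0) matches both zeros; mapping both to +0.0 is a no-op for +0.0 and exactly the desired normalization for -0.0. Plain Python floats follow the same IEEE-754 rules, so the semantics can be checked directly:

    import math

    assert -0.0 == 0.0                       # equality ignores the sign of zero
    assert math.copysign(1.0, -0.0) == -1.0  # but -0.0 still carries its sign bit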
Second changed file (a new test file; path not shown in this view): 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Owner(s): ["module: intel"]
+import torch
+from torch.testing._internal.common_device_type import (
+    dtypes,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_dtype import float8_types
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+cpu_device = torch.device("cpu")
+xpu_device = torch.device("xpu")
+
+
+class TestSimpleConversion(TestCase):
+    def _compare_convert_with_cpu(self, src_cpu, dtype):
+        src_xpu = src_cpu.to(xpu_device)
+        dst_cpu = src_cpu.to(dtype)
+        dst_xpu = src_xpu.to(dtype)
+        self.assertEqual(dst_xpu.to(cpu_device), dst_cpu)
+
+    @dtypes(*float8_types())
+    def test_half_zero(self, dtype):
+        pos_zero_fp16_cpu = torch.zeros((5, 6), dtype=torch.float16)
+        self._compare_convert_with_cpu(pos_zero_fp16_cpu, dtype)
+
+        neg_zero_fp16_cpu = torch.full((5, 6), -0.0, dtype=torch.float16)
+        self._compare_convert_with_cpu(neg_zero_fp16_cpu, dtype)
+
+    @dtypes(*float8_types())
+    def test_half_nonzero(self, dtype):
+        x_fp16_cpu = torch.arange(-100.0, 101.0, dtype=torch.float16)
+        self._compare_convert_with_cpu(x_fp16_cpu, dtype)
+
+
+instantiate_device_type_tests(
+    TestSimpleConversion, globals(), only_for="xpu", allow_xpu=True
+)
+
+
+if __name__ == "__main__":
+    run_tests()
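
One subtlety in the assertion used by _compare_convert_with_cpu: the fix maps fp16 -0.0 to fp8 +0.0 on XPU, while the CPU path may preserve the sign of zero, yet the test still passes because -0.0 and +0.0 compare as numerically equal, whereas the pre-fix NaN result fails any such comparison. A quick illustration of those comparison semantics with torch.testing.assert_close:

    import torch

    # -0.0 and +0.0 are numerically equal, so this passes.
    torch.testing.assert_close(torch.tensor([-0.0]), torch.tensor([0.0]))
    # A NaN result, by contrast, would raise (NaN is not equal to 0.0):
    # torch.testing.assert_close(torch.tensor([float("nan")]), torch.tensor([0.0]))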
