Skip to content

Commit 993ab70

Browse files
authored
Add float4_e2m1fn_x2 support for concat (#2315)
This PR adds support for the `float4_e2m1fn_x2` data type to the `cat` (concatenate) kernel on XPU devices.
1 parent 88743cd commit 993ab70

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/ATen/native/xpu/sycl/Shape.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,8 @@ void cat_out_kernel(
395395
kBool,
396396
kBFloat16,
397397
AT_EXPAND(AT_FLOAT8_TYPES),
398-
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
398+
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
399+
kFloat4_e2m1fn_x2);
399400
} else {
400401
offset = 0;
401402
for (j = 0; j < numInputs; j++) {

test/regressions/test_cat.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from torch.testing._internal.common_dtype import float8_types_and
88
from torch.testing._internal.common_utils import run_tests, TestCase
99

10+
cpu_device = torch.device("cpu")
11+
xpu_device = torch.device("xpu")
12+
1013

1114
class TestTorchMethod(TestCase):
1215
def _create_input_tensors(self, shape, dtype, memory_format=None):
@@ -61,6 +64,21 @@ def test_cat_simple(self, dtype):
6164

6265
self._test_cat_float8_core(tensors, dim, dtype)
6366

67+
def _float4_dummy_tensor(self, shape, device):
68+
data = torch.ones(shape, dtype=torch.uint8, device=device)
69+
return data.view(torch.float4_e2m1fn_x2)
70+
71+
def test_cat_float4_simple(self):
72+
input_cpu1 = self._float4_dummy_tensor([2, 2, 6], device=cpu_device)
73+
input_cpu2 = self._float4_dummy_tensor([2, 2, 6], device=cpu_device)
74+
output_cpu = torch.stack([input_cpu1, input_cpu2]).view(torch.uint8)
75+
76+
input_xpu1 = self._float4_dummy_tensor([2, 2, 6], device=xpu_device)
77+
input_xpu2 = self._float4_dummy_tensor([2, 2, 6], device=xpu_device)
78+
output_xpu = torch.stack([input_xpu1, input_xpu2]).view(torch.uint8)
79+
80+
self.assertEqual(output_xpu, output_cpu)
81+
6482
def test_cat_8d(self, dtype=torch.float):
6583
input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
6684
input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)

0 commit comments

Comments
 (0)