|
7 | 7 | from torch.testing._internal.common_dtype import float8_types_and |
8 | 8 | from torch.testing._internal.common_utils import run_tests, TestCase |
9 | 9 |
|
# Device handles shared by the tests below, which run the same operation on
# both devices and compare the results.
cpu_device, xpu_device = torch.device("cpu"), torch.device("xpu")
| 12 | + |
10 | 13 |
|
11 | 14 | class TestTorchMethod(TestCase): |
12 | 15 | def _create_input_tensors(self, shape, dtype, memory_format=None): |
@@ -61,6 +64,21 @@ def test_cat_simple(self, dtype): |
61 | 64 |
|
62 | 65 | self._test_cat_float8_core(tensors, dim, dtype) |
63 | 66 |
|
| 67 | + def _float4_dummy_tensor(self, shape, device): |
| 68 | + data = torch.ones(shape, dtype=torch.uint8, device=device) |
| 69 | + return data.view(torch.float4_e2m1fn_x2) |
| 70 | + |
def test_cat_float4_simple(self):
    """Stack two float4 dummy tensors on CPU and on XPU and compare the bytes.

    The stacked results are viewed as ``torch.uint8`` before comparison, so
    the assertion checks the raw byte contents produced on each device.
    """
    shape = [2, 2, 6]

    def stacked_bytes(device):
        # Build both inputs on the target device, stack them, then expose
        # the result as uint8 for the comparison.
        pair = [self._float4_dummy_tensor(shape, device=device) for _ in range(2)]
        return torch.stack(pair).view(torch.uint8)

    # CPU result is computed first and serves as the reference.
    output_cpu = stacked_bytes(cpu_device)
    output_xpu = stacked_bytes(xpu_device)

    self.assertEqual(output_xpu, output_cpu)
| 81 | + |
64 | 82 | def test_cat_8d(self, dtype=torch.float): |
65 | 83 | input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype) |
66 | 84 | input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype) |
|
0 commit comments