1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/Shape.cpp
@@ -394,6 +394,7 @@ void cat_out_kernel(
kHalf,
kBool,
kBFloat16,
AT_EXPAND(AT_FLOAT8_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
} else {
offset = 0;
15 changes: 11 additions & 4 deletions src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
@@ -1,4 +1,5 @@
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
@@ -78,10 +79,16 @@ struct ClampScalarFunctor {
};

void where_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
gpu_kernel(iter, WhereFunctor<scalar_t>());
});
AT_DISPATCH_V2(
iter.dtype(),
"where_xpu",
[&] { gpu_kernel(iter, WhereFunctor<scalar_t>()); },
kComplexHalf,
kHalf,
kBFloat16,
kBool,
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES));
}

void isposinf_kernel(TensorIteratorBase& iter) {
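Taken together, the two kernel changes above extend the XPU `cat` and `where` dispatch tables to the float8 dtypes. As a minimal sketch (not part of the diff, and assuming an XPU-enabled PyTorch build that exposes `torch.float8_e4m3fn` and the `xpu` device), the new paths would be exercised from Python roughly as follows; the regression tests added below cover the same ground more systematically:

```python
import torch

# Hypothetical smoke test for the new float8 dispatch; mirrors the regression tests below.
a = torch.randn(2, 4, 3).to(torch.float8_e4m3fn).xpu()  # cast on CPU, then move to XPU
b = torch.randn(2, 2, 3).to(torch.float8_e4m3fn).xpu()

# cat_out_kernel now dispatches over AT_FLOAT8_TYPES, so this should run on the xpu device.
out = torch.cat((a, b), dim=1)

# where_kernel is likewise dispatched via AT_DISPATCH_V2 with AT_FLOAT8_TYPES.
cond = torch.tensor([[True, False, True], [False, True, False]], device="xpu")
x = torch.randn(2, 3).to(torch.float8_e4m3fn).xpu()
y = torch.randn(2, 3).to(torch.float8_e4m3fn).xpu()
sel = torch.where(cond, x, y)
```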
61 changes: 61 additions & 0 deletions test/regressions/test_cat.py
@@ -4,6 +4,67 @@


class TestTorchMethod(TestCase):
# Define float8 dtypes for the focused test
FLOAT8_DTYPES = (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
)

def _create_input_tensors(self, shape, dtype, memory_format=None):
# Always generate random data using a CPU-compatible dtype (float32)
# to avoid the "not implemented" error for float8 on CPU.
tensor = torch.randn(shape, dtype=torch.float32)

# Convert to the target testing dtype
tensor = tensor.to(dtype)

# Apply memory format if specified
if memory_format is not None:
tensor = tensor.to(memory_format=memory_format)

return tensor

def _test_cat_float8_core(self, tensors, dim, dtype):
"""Core function to test torch.cat for float8, using tolerances."""

# --- CPU Reference Calculation (High Precision) ---
# Convert inputs to float32 on CPU for golden reference calculation
ref_tensors = [t.cpu().to(torch.float32) for t in tensors]

# Calculate CPU reference result
res_cpu = torch.cat(ref_tensors, dim=dim)

# --- XPU Calculation ---
# Convert inputs to XPU
xpu_tensors = [t.xpu() for t in tensors]
res_xpu = torch.cat(xpu_tensors, dim=dim)

# Float8 is lossy, so use a higher tolerance (rtol=1e-2, atol=1e-2)
rtol = 1e-2
atol = 1e-2

# Convert XPU result to float32 on CPU before comparison to match res_cpu's dtype.
res_xpu_f32_on_cpu = res_xpu.cpu().to(torch.float32)

self.assertEqual(res_cpu, res_xpu_f32_on_cpu, rtol=rtol, atol=atol)

def test_cat_float8_simple(self):
"""Test torch.cat correctness across float8 dtypes using simple tensors."""
for dtype in self.FLOAT8_DTYPES:
with self.subTest(dtype=dtype):
# Use simple 3D shape (2, 4, 3) and concatenate along dim 1
user_cpu1 = self._create_input_tensors([2, 4, 3], dtype=dtype)
user_cpu2 = self._create_input_tensors([2, 2, 3], dtype=dtype)
user_cpu3 = self._create_input_tensors([2, 6, 3], dtype=dtype)

tensors = (user_cpu1, user_cpu2, user_cpu3)
dim = 1

self._test_cat_float8_core(tensors, dim, dtype)

def test_cat_8d(self, dtype=torch.float):
input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
92 changes: 92 additions & 0 deletions test/regressions/test_where.py
@@ -0,0 +1,92 @@
# Owner(s): ["module: intel"]
import torch
from torch.testing._internal.common_utils import TestCase


class TestTorchWhereMethod(TestCase):
# Define float8 dtypes
FLOAT8_DTYPES = (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
)

# Define the set of all dtypes to be tested
TEST_DTYPES = (
torch.float32,
torch.float64,
torch.half,
torch.bfloat16,
) + FLOAT8_DTYPES

def _test_where_fn(self, dtype):
"""Core function to test torch.where(condition, x, y) correctness."""

# Input tensors (x and y)
x = torch.tensor([[10.0, 20.0], [30.0, 40.0]], dtype=dtype)
y = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]], dtype=dtype)
# Condition must be bool
condition = torch.tensor([[True, False], [False, True]], dtype=torch.bool)

# --- 1. CPU Reference Calculation and Tolerance Setting ---

if dtype in self.FLOAT8_DTYPES:
# FP8: Use float32 as reference type for comparison
x_ref = x.cpu().to(torch.float32)
y_ref = y.cpu().to(torch.float32)
rtol = 1e-2
atol = 1e-2
else:
# Non-FP8: Use original dtype as reference type
x_ref = x.cpu()
y_ref = y.cpu()
rtol = 1e-5
atol = 1e-5

condition_ref = condition.cpu()
res_ref = torch.where(condition_ref, x_ref, y_ref)

# --- 2. XPU Operation (Default) ---
x_xpu = x.xpu()
y_xpu = y.xpu()
condition_xpu = condition.xpu()

res_xpu = torch.where(condition_xpu, x_xpu, y_xpu)

# Prepare XPU result for comparison (must match res_ref dtype)
if dtype in self.FLOAT8_DTYPES:
# FP8: Convert XPU result to float32
res_xpu_to_compare = res_xpu.cpu().to(torch.float32)
else:
# Non-FP8: Pull to CPU, keeping original dtype
res_xpu_to_compare = res_xpu.cpu()

# Compare: res_ref vs res_xpu_to_compare
self.assertEqual(res_ref, res_xpu_to_compare, rtol=rtol, atol=atol)

# --- 3. Test the version with out= argument ---

# Create output tensor on XPU
res_xpu_out = torch.empty_like(res_xpu, dtype=dtype).xpu()
torch.where(condition_xpu, x_xpu, y_xpu, out=res_xpu_out)

# Prepare XPU 'out' result for comparison
if dtype in self.FLOAT8_DTYPES:
# FP8: Convert XPU result to float32
res_xpu_out_to_compare = res_xpu_out.cpu().to(torch.float32)
else:
# Non-FP8: Pull to CPU, keeping original dtype
res_xpu_out_to_compare = res_xpu_out.cpu()

# Compare: res_ref vs res_xpu_out_to_compare
self.assertEqual(res_ref, res_xpu_out_to_compare, rtol=rtol, atol=atol)

def test_where(self):
"""Test torch.where() correctness across all supported dtypes, including float8."""
for dtype in self.TEST_DTYPES:
# Use string conversion for better subTest reporting
dtype_name = str(dtype).split(".")[-1]
with self.subTest(dtype=dtype_name):
self._test_where_fn(dtype)