Skip to content

Commit 2b67d57

Browse files
authored
Fix pdist numerical issue on large input size (#2181)
This fixes #2021. The current algorithm uses the following formula to calculate the index, and it is numerically unstable when n becomes large (and j = n-1). The sqrt must be computed in double precision to ensure a correct result after the cast to an integer. This change also aligns with the CUDA implementation. https://github.com/intel/torch-xpu-ops/blob/779f89911779b8c7296aaec3cf74945c18acc270/src/ATen/native/xpu/sycl/DistanceKernels.cpp#L732 <img width="1322" height="138" alt="image" src="https://github.com/user-attachments/assets/b7fa1515-eb29-436b-b14e-ac2c393df730" />
1 parent e78fc05 commit 2b67d57

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

src/ATen/native/xpu/sycl/DistanceKernels.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
729729
const size_t stride = item_id.get_local_range().size();
730730

731731
int64_t i = static_cast<int64_t>(
732-
(n2_val_ - device_sqrt<accscalar_t>(n2_squared_minus_1_val_ - 2 * k)));
732+
(n2_val_ - device_sqrt<double>(n2_squared_minus_1_val_ - 2 * k)));
733733
int64_t j = k - n_ * i + i * (i + 1) / 2 + i + 1;
734734

735735
const scalar_t* const start = in_ptr + i * m_;
@@ -760,8 +760,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
760760
const int64_t n,
761761
const int64_t m,
762762
accscalar_t p_val,
763-
accscalar_t n2_val,
764-
accscalar_t n2_squared_minus_1_val,
763+
const double n2_val,
764+
const double n2_squared_minus_1_val,
765765
scalar_t* out_data,
766766
const scalar_t* in_data,
767767
const int64_t wgroup_size)
@@ -778,8 +778,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
778778
const int64_t n_;
779779
const int64_t m_;
780780
accscalar_t p_val_;
781-
accscalar_t n2_val_;
782-
accscalar_t n2_squared_minus_1_val_;
781+
const double n2_val_;
782+
const double n2_squared_minus_1_val_;
783783
scalar_t* out_data_;
784784
const scalar_t* in_data_;
785785
sycl_local_acc_t<scalar_t, 1> shared_;
@@ -805,8 +805,6 @@ static void pdist_kernel_impl(
805805
}
806806

807807
auto p_val = static_cast<accscalar_t>(p);
808-
auto n2_val = static_cast<accscalar_t>(n2);
809-
auto n2_squared_minus_1_val = static_cast<accscalar_t>(n2_squared_minus_1);
810808

811809
auto out_data = result.mutable_data_ptr<scalar_t>();
812810
auto in_data = self.const_data_ptr<scalar_t>();
@@ -815,8 +813,8 @@ static void pdist_kernel_impl(
815813
n,
816814
m,
817815
p_val,
818-
n2_val,
819-
n2_squared_minus_1_val,
816+
n2,
817+
n2_squared_minus_1,
820818
out_data,
821819
in_data,
822820
wgroup_size / min_sg_size);

0 commit comments

Comments
 (0)