diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index e4a8444c9..c68188005 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -1063,8 +1063,10 @@ void _layer_norm_backward_kernel(
       norm_config_global_size / syclMaxSubGroupSize() * 2 <= thread_slots;
   // cuda uses condition M > 64 * 1024 && N / 32 < sm_count / 2 to parallelize
   // in the M dimension
-  if (use_two_stage_col_reduction && M > 64 * 1024 &&
-      N / 32 < syclGpuEuCount() / syclGpuEUCountPerSubslice() / 2) {
+  int xe_core_count = syclGpuEuCount() / syclGpuEUCountPerSubslice();
+  int tile_n = N / 32;
+  if (use_two_stage_col_reduction && M > xe_core_count * 1024 &&
+      tile_n < xe_core_count * 2) {
     const size_t local_size_x = 8;
     const size_t SIMD = 32;
     // workgroup size is 256
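For reviewers, here is the changed heuristic pulled out as a standalone predicate. This is a minimal sketch only: the function name `should_parallelize_m` and the stubbed parameters `eu_count` / `eu_per_subslice` are hypothetical, standing in for the real device queries `syclGpuEuCount()` and `syclGpuEUCountPerSubslice()`.

```cpp
#include <cstdint>

// Sketch of the new dispatch condition from the diff above: decide whether
// the layer-norm backward pass should parallelize the column reduction
// across the M (row) dimension. The CUDA analogue quoted in the source
// comment is: M > 64 * 1024 && N / 32 < sm_count / 2.
bool should_parallelize_m(
    int64_t M,                         // rows (outer dimension)
    int64_t N,                         // normalized (inner) dimension
    int eu_count,                      // stand-in for syclGpuEuCount()
    int eu_per_subslice,               // stand-in for syclGpuEUCountPerSubslice()
    bool use_two_stage_col_reduction) {
  // One Xe core corresponds to one subslice worth of EUs.
  int xe_core_count = eu_count / eu_per_subslice;
  // Each column tile covers 32 columns (the kernel uses SIMD = 32).
  int64_t tile_n = N / 32;
  return use_two_stage_col_reduction &&
      M > static_cast<int64_t>(xe_core_count) * 1024 &&
      tile_n < static_cast<int64_t>(xe_core_count) * 2;
}
```

Note that, compared with the CUDA condition quoted in the source comment, both thresholds now scale with the Xe-core count: the row cutoff becomes `xe_core_count * 1024` rather than the fixed `64 * 1024`, and the column-tile bound becomes `xe_core_count * 2` rather than `sm_count / 2`.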