diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index e4a8444c94..87072b765a 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -181,7 +181,7 @@ bool can_vectorize(const T* ptr, int alignment) {
   return addr % alignment == 0;
 };
 
-template <typename T, typename T_ACC>
+template <typename T, typename T_ACC, bool rms_norm>
 struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
   using WelfordType = WelfordData<T_ACC, int64_t>;
   using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
@@ -204,8 +204,12 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       T_ACC m1;
       T_ACC m2;
       std::tie(m2, m1) = welford_op.project(val);
-      mean_[i] = m1;
-      rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
+      if constexpr (!rms_norm) {
+        mean_[i] = m1;
+        rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
+      } else {
+        rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_);
+      }
     }
   }
 
@@ -249,7 +253,7 @@ void launch_rowwise_moments_kernel(
   sycl_kernel_submit(global_range, local_range, queue, kfn);
 }
 
-template <typename T, typename T_ACC>
+template <typename T, typename T_ACC, bool rms_norm>
 struct LayerNormForwardKernelFunctor {
   void operator()(sycl::nd_item<1> item_id) const {
     const int64_t i = item_id.get_group(0);
@@ -258,12 +262,17 @@ struct LayerNormForwardKernelFunctor {
       const int64_t index = i * N_ + j;
       const T_ACC gamma_v =
          gamma_ == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma_[j]);
-      const T_ACC beta_v =
-          beta_ == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta_[j]);
-      Y_[index] =
-          (static_cast<T_ACC>(X_[index]) - static_cast<T_ACC>(mean_[i])) *
-              static_cast<T_ACC>(rstd_[i]) * gamma_v +
-          beta_v;
+      if constexpr (!rms_norm) {
+        const T_ACC beta_v =
+            beta_ == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta_[j]);
+        Y_[index] =
+            (static_cast<T_ACC>(X_[index]) - static_cast<T_ACC>(mean_[i])) *
+                static_cast<T_ACC>(rstd_[i]) * gamma_v +
+            beta_v;
+      } else {
+        Y_[index] = (static_cast<T_ACC>(X_[index])) *
+            static_cast<T_ACC>(rstd_[i]) * gamma_v;
+      }
     }
   }
   LayerNormForwardKernelFunctor(
@@ -323,17 +332,17 @@ struct WelfordDataLN {
       : mean(mean), sigma2(sigma2), count(count) {}
 };
 
-template <typename U>
+template <bool rms_norm, typename U>
 WelfordDataLN WelfordOnlineSum(const U val, const WelfordDataLN& curr_sum) {
-  U delta = val - curr_sum.mean;
-  U new_count = curr_sum.count + 1.f;
-  U new_mean = curr_sum.mean +
-      delta * (1.f / new_count); // proper division is slow, this is less
-                                 // accurate but noticeably faster
-  return {
-      static_cast<float>(new_mean),
-      static_cast<float>(curr_sum.sigma2 + delta * (val - new_mean)),
-      static_cast<float>(new_count)};
+  if constexpr (!rms_norm) {
+    U delta = val - curr_sum.mean;
+    U new_count = curr_sum.count + 1.f;
+    // proper division is slow, this is less accurate but noticeably faster
+    U new_mean = curr_sum.mean + delta * (1.f / new_count);
+    return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
+  } else {
+    return {0.f, curr_sum.sigma2 + val * val, 0};
+  }
 }
 
 WelfordDataLN WelfordCombine(
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index a3281791de..0124d77042 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -3272,6 +3272,14 @@
   autogen: native_layer_norm_backward.out
   tags: core
 
+- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
+  dispatch:
+    XPU: _fused_rms_norm_xpu
+
+- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
+  dispatch:
+    XPU: _fused_rms_norm_backward_xpu
+
 - func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
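
Aside (not part of the patch): the kernel changes rely on the identity E[x^2] = Var[x] + mean^2, which is why the existing Welford pass, producing mean m1 and variance m2, can also serve the RMSNorm path, whose normalizer is rsqrt(E[x^2] + eps) rather than rsqrt(Var[x] + eps). The standalone host-side C++ sketch below only illustrates that identity; the toy input and variable names are illustrative and do not come from the patch.

// Standalone illustration (not from the patch): LayerNorm's rstd uses the
// variance, while RMSNorm's rstd uses E[x^2]; since E[x^2] = Var[x] + mean^2,
// a pass that yields (mean, variance) also covers the RMSNorm case, which is
// why the kernel above computes rsqrt(m2 + m1 * m1 + eps_).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f}; // toy row
  const float eps = 1e-5f;

  float mean = 0.f, mean_sq = 0.f;
  for (float v : x) {
    mean += v;
    mean_sq += v * v;
  }
  mean /= x.size();
  mean_sq /= x.size();
  const float var = mean_sq - mean * mean; // Var[x] = E[x^2] - mean^2

  const float layer_rstd = 1.f / std::sqrt(var + eps);     // LayerNorm normalizer
  const float rms_rstd = 1.f / std::sqrt(mean_sq + eps);   // RMSNorm, direct form
  const float rms_from_welford = 1.f / std::sqrt(var + mean * mean + eps);

  // rms_rstd and rms_from_welford agree up to rounding.
  std::printf("layer_rstd=%f rms_rstd=%f rms_from_welford=%f\n",
              layer_rstd, rms_rstd, rms_from_welford);
  return 0;
}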