@@ -9,27 +9,26 @@ namespace at::native::xpu {
99
1010template <typename scalar_t >
1111struct IgammaFunctor {
12- scalar_t operator ()(scalar_t a, scalar_t b) const {
13- return calc_igamma<scalar_t >(a, b);
14- }
15- };
16-
17- template <typename scalar_t >
18- struct IgammacFunctor {
19- scalar_t operator ()(scalar_t a, scalar_t b) const {
20- return calc_igammac<scalar_t >(a, b);
12+ IgammaFunctor (bool calc_igammac) : calc_igammac_(calc_igammac) {}
13+ bool calc_igammac_;
14+ [[clang::optnone]] scalar_t operator ()(scalar_t a, scalar_t b) const {
15+ if (calc_igammac_) {
16+ return calc_igammac<scalar_t >(a, b);
17+ } else {
18+ return calc_igamma<scalar_t >(a, b);
19+ }
2120 }
2221};
2322
2423void igamma_kernel (TensorIteratorBase& iter) {
2524 AT_DISPATCH_FLOATING_TYPES (iter.common_dtype (), " igamma_xpu" , [&]() {
26- gpu_kernel (iter, IgammaFunctor<scalar_t >());
25+ gpu_kernel (iter, IgammaFunctor<scalar_t >(false ));
2726 });
2827}
2928
3029void igammac_kernel (TensorIteratorBase& iter) {
3130 AT_DISPATCH_FLOATING_TYPES (iter.common_dtype (), " igammac_xpu" , [&]() {
32- gpu_kernel (iter, IgammacFunctor <scalar_t >());
31+ gpu_kernel (iter, IgammaFunctor <scalar_t >(true ));
3332 });
3433}
3534
0 commit comments