@@ -305,7 +305,7 @@ struct LegacyKernelWithCastScalarFunctor {
   const func_t f_;
 };
 
-template <int vec_size, typename func_t>
+template <int vec_size, typename func_t, bool force_small_grf = false>
 static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
@@ -316,7 +316,12 @@ static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
 
   int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
   int64_t num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);
-  sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  if constexpr (force_small_grf) {
+    sycl_kernel_submit_small_grf(
+        wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  } else {
+    sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  }
 }
 
 template <typename func_t>
@@ -341,7 +346,8 @@ template <
     typename in_calc_t,
     typename out_calc_t,
     typename loader_t,
-    typename storer_t>
+    typename storer_t,
+    bool force_small_grf = false>
 static inline void launch_unrolled_kernel(
     int64_t N,
     const func_t& f,
@@ -357,7 +363,12 @@ static inline void launch_unrolled_kernel(
 
   int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
   int64_t num_wg = ceil_div<int64_t>(N, wg_sz * ker_t::item_work_size);
-  sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  if constexpr (force_small_grf) {
+    sycl_kernel_submit_small_grf(
+        wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  } else {
+    sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+  }
 }
 
 constexpr int max_scalar_size_(std::tuple<>) {
@@ -570,7 +581,14 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
     auto storer = memory::StoreWithCast<1>(iter);
     auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
     auto output_offset_calculator = TrivialOffsetCalculator<1>();
-    launch_unrolled_kernel(
+    launch_unrolled_kernel<
+        func_t,
+        decltype(data),
+        decltype(input_offset_calculator),
+        decltype(output_offset_calculator),
+        decltype(loader),
+        decltype(storer),
+        true>(
         numel,
         f,
         data,
@@ -585,13 +603,13 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
     }
     auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
     constexpr int unroll_factor = sizeof(arg0_t) > 4 ? 2 : 4;
-    launch_legacy_group_range_kernel<unroll_factor>(
-        numel,
-        LegacyKernelWithCastScalarFunctor<
-            arg0_t,
-            ntensors,
-            decltype(offset_calc),
-            func_t>(data, dtypes, offset_calc, f));
+    using functor = LegacyKernelWithCastScalarFunctor<
+        arg0_t,
+        ntensors,
+        decltype(offset_calc),
+        func_t>;
+    launch_legacy_group_range_kernel<unroll_factor, functor, true>(
+        numel, functor(data, dtypes, offset_calc, f));
   }
 }
 
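Not part of the patch, for context: both launch helpers gain a defaulted `force_small_grf` non-type template parameter and pick the submit path with `if constexpr`, so the untaken branch is never instantiated and existing callers, which leave the default `false`, compile exactly as before. On Intel GPUs, GRF is the General Register File; a small-GRF submission, which `sycl_kernel_submit_small_grf` presumably requests, trades registers per hardware thread for higher occupancy. A minimal standalone sketch of the dispatch pattern, with `submit_default`/`submit_small_grf` as hypothetical stand-ins for the real submit helpers:

```cpp
#include <cstdio>
#include <utility>

// Hypothetical stand-ins for sycl_kernel_submit / sycl_kernel_submit_small_grf.
void submit_default() { std::puts("default submit"); }
void submit_small_grf() { std::puts("small-GRF submit"); }

// Mirrors the patch: a defaulted bool template parameter selects the submit
// path at compile time; only the taken branch is instantiated.
template <bool force_small_grf = false, typename func_t>
void launch(func_t&& kernel) {
  if constexpr (force_small_grf) {
    submit_small_grf();
  } else {
    submit_default();
  }
  std::forward<func_t>(kernel)();
}

int main() {
  launch([] { std::puts("kernel body"); });        // default path
  launch<true>([] { std::puts("kernel body"); });  // small-GRF path
}
```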
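One more hedged note on the call sites: because the new parameter is last, overriding its default requires spelling out every template argument that precedes it. That is why `gpu_kernel_impl` now lists `func_t` plus five `decltype(...)` arguments before `true`, and why the legacy path hoists the functor type into a `using functor = ...` alias so it can be named inside the template argument list. A self-contained reduction of the idiom (`launch_kernel`, `loader`, and `storer` are illustrative names, not the real API):

```cpp
#include <cstdio>

// The trailing bool sits after deduced type parameters, so overriding its
// default means writing out every preceding template argument explicitly.
template <typename loader_t, typename storer_t, bool force_small_grf = false>
void launch_kernel(const loader_t&, const storer_t&) {
  std::puts(force_small_grf ? "small-GRF submit" : "default submit");
}

int main() {
  int loader = 0;
  float storer = 0.f;
  launch_kernel(loader, storer);  // deduced, force_small_grf stays false
  // As in the patch, decltype names the deduced types so true can be passed:
  launch_kernel<decltype(loader), decltype(storer), true>(loader, storer);
}
```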