Skip to content

Commit 4f058c2

Browse files
committed
set small grf for dynamic cast
1 parent 5b9b8d2 commit 4f058c2

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

src/ATen/native/xpu/sycl/Loops.h

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ struct LegacyKernelWithCastScalarFunctor {
305305
const func_t f_;
306306
};
307307

308-
template <int vec_size, typename func_t>
308+
template <int vec_size, typename func_t, bool force_small_grf = false>
309309
static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
310310
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
311311
if (N == 0) {
@@ -316,7 +316,12 @@ static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
316316

317317
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
318318
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);
319-
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
319+
if constexpr (force_small_grf) {
320+
sycl_kernel_submit_small_grf(
321+
wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
322+
} else {
323+
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
324+
}
320325
}
321326

322327
template <typename func_t>
@@ -341,7 +346,8 @@ template <
341346
typename in_calc_t,
342347
typename out_calc_t,
343348
typename loader_t,
344-
typename storer_t>
349+
typename storer_t,
350+
bool force_small_grf = false>
345351
static inline void launch_unrolled_kernel(
346352
int64_t N,
347353
const func_t& f,
@@ -357,7 +363,12 @@ static inline void launch_unrolled_kernel(
357363

358364
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
359365
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * ker_t::item_work_size);
360-
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
366+
if constexpr (force_small_grf) {
367+
sycl_kernel_submit_small_grf(
368+
wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
369+
} else {
370+
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
371+
}
361372
}
362373

363374
constexpr int max_scalar_size_(std::tuple<>) {
@@ -570,7 +581,14 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
570581
auto storer = memory::StoreWithCast<1>(iter);
571582
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
572583
auto output_offset_calculator = TrivialOffsetCalculator<1>();
573-
launch_unrolled_kernel(
584+
launch_unrolled_kernel<
585+
func_t,
586+
decltype(data),
587+
decltype(input_offset_calculator),
588+
decltype(output_offset_calculator),
589+
decltype(loader),
590+
decltype(storer),
591+
true>(
574592
numel,
575593
f,
576594
data,
@@ -585,13 +603,13 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
585603
}
586604
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
587605
constexpr int unroll_factor = sizeof(arg0_t) > 4 ? 2 : 4;
588-
launch_legacy_group_range_kernel<unroll_factor>(
589-
numel,
590-
LegacyKernelWithCastScalarFunctor<
591-
arg0_t,
592-
ntensors,
593-
decltype(offset_calc),
594-
func_t>(data, dtypes, offset_calc, f));
606+
using functor = LegacyKernelWithCastScalarFunctor<
607+
arg0_t,
608+
ntensors,
609+
decltype(offset_calc),
610+
func_t>;
611+
launch_legacy_group_range_kernel<unroll_factor, functor, true>(
612+
numel, functor(data, dtypes, offset_calc, f));
595613
}
596614
}
597615

src/comm/SYCLHelpers.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#pragma once
22

33
#include <comm/Scalar.h>
4+
#include <sycl/ext/intel/experimental/grf_size_properties.hpp>
5+
#include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
6+
#include <sycl/ext/oneapi/properties/properties.hpp>
47
#include <sycl/sycl.hpp>
58

69
// sycl access address space
@@ -139,6 +142,52 @@ sycl_kernel_submit(
139142
q.submit(cgf);
140143
}
141144

145+
template <typename ker_t>
146+
struct SmallGRF;
147+
148+
template <typename ker_t>
149+
static inline typename std::enable_if<
150+
std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>,
151+
void>::type
152+
sycl_kernel_submit_small_grf(
153+
int64_t global_range,
154+
int64_t local_range,
155+
::sycl::queue q,
156+
ker_t ker) {
157+
::sycl::ext::oneapi::experimental::properties kernel_props{
158+
::sycl::ext::intel::experimental::grf_size<128>};
159+
auto range = ::sycl::nd_range<1>(
160+
::sycl::range<1>(global_range), ::sycl::range<1>(local_range));
161+
::sycl::ext::oneapi::experimental::launch_config config(range, kernel_props);
162+
auto cgf = [&](::sycl::handler& cgh) {
163+
ker.sycl_ker_config_convention(cgh);
164+
::sycl::ext::oneapi::experimental::nd_launch<SmallGRF<ker_t>>(
165+
cgh, config, ker);
166+
};
167+
q.submit(cgf);
168+
}
169+
170+
template <typename ker_t>
171+
static inline typename std::enable_if<
172+
!std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>,
173+
void>::type
174+
sycl_kernel_submit_small_grf(
175+
int64_t global_range,
176+
int64_t local_range,
177+
::sycl::queue q,
178+
ker_t ker) {
179+
::sycl::ext::oneapi::experimental::properties kernel_props{
180+
::sycl::ext::intel::experimental::grf_size<128>};
181+
auto range = ::sycl::nd_range<1>(
182+
::sycl::range<1>(global_range), ::sycl::range<1>(local_range));
183+
::sycl::ext::oneapi::experimental::launch_config config(range, kernel_props);
184+
auto cgf = [&](::sycl::handler& cgh) {
185+
::sycl::ext::oneapi::experimental::nd_launch<SmallGRF<ker_t>>(
186+
cgh, config, ker);
187+
};
188+
q.submit(cgf);
189+
}
190+
142191
#define SYCL_KERNEL_STRING(var, str) \
143192
static const __attribute__((opencl_constant)) char var[] = str;
144193
#define SYCL_KERNEL_PRINTF sycl::ext::oneapi::experimental::printf

0 commit comments

Comments
 (0)