Skip to content

Commit 27c4a26

Browse files
avbokovoyfacebook-github-bot
authored andcommitted
group_index_select_or_add_2d_kernel forward pass optimization (#5078)
Summary: This PR introduces optimization for `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel with primary focus on `float` type and relatively small embedding dimensions. 2 things are implemented: 1) Extracted the common variables out of the loop to omit unnecessary synchronizations on memory load (compiler won't do that automatically) 2) Switch to 32 threads logical wave sizes to reduce granularity losses. Differential Revision: D86135611 Pulled By: q10
1 parent ecf2ac9 commit 27c4a26

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,19 @@ using Tensor = at::Tensor;
1212

1313
namespace fbgemm_gpu {
1414

15+
// The wave size is forced to be 32 on ROCm devices in favor
16+
// of granularity losses reduction.
17+
18+
#ifdef USE_ROCM
19+
constexpr int EMULATED_WARP_SIZE = 32;
20+
#else
21+
constexpr int EMULATED_WARP_SIZE = kWarpSize;
22+
#endif
23+
1524
// TODO: Update UNROLL_FACTOR
1625
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
1726
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
18-
GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
27+
GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
1928

2029
// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
2130
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -43,12 +52,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
4352
const int64_t num_work_rows, // number of rows to work on per member
4453
const int64_t group_size) {
4554
const auto total_num_warps = warp_offsets_group[group_size];
55+
int32_t num_cols = 0;
56+
int32_t warps_per_row = 0;
57+
58+
if constexpr (!USE_VAR_COLS) {
59+
num_cols = num_cols_group[0];
60+
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
61+
}
62+
4663
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
4764
warp_id < total_num_warps;
4865
warp_id += gridDim.x * blockDim.y) {
49-
int32_t member_id, member_warp_id, num_cols, warps_per_row;
50-
if (USE_VAR_COLS) {
51-
__shared__ int member_ids[kMaxThreads / kWarpSize];
66+
int32_t member_id = 0;
67+
int32_t member_warp_id = 0;
68+
if constexpr (USE_VAR_COLS) {
69+
__shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
5270
if (threadIdx.x == 0) {
5371
binary_search_range(
5472
&member_ids[threadIdx.y],
@@ -63,8 +81,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
6381
member_warp_id = warp_id - warp_offsets_group[member_id];
6482
} else {
6583
// All columns are the same
66-
num_cols = num_cols_group[0];
67-
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
6884
member_id = warp_id / (warps_per_row * num_work_rows);
6985
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
7086
}
@@ -82,7 +98,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
8298
#pragma unroll
8399
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
84100
// Compile time conditional
85-
if (USE_INDEX_SELECT) {
101+
if constexpr (USE_INDEX_SELECT) {
86102
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
87103
} else {
88104
gpuAtomicAddNoReturn(
@@ -113,13 +129,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
113129
at::cuda::OptionalCUDAGuard device_guard(device);
114130

115131
// Partition work based on num_work_rows
116-
uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
132+
uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
117133
uint32_t max_grid_size =
118134
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
119135
uint32_t grid_size = std::min(
120136
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
121137
max_grid_size);
122-
dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
138+
dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
123139

124140
#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
125141
FBGEMM_LAUNCH_KERNEL( \

0 commit comments

Comments
 (0)