From 23e13e34897ab2a8dbc81922a924311b399b10f2 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 3 Nov 2025 21:45:40 -0800 Subject: [PATCH] group_index_select_or_add_2d_kernel forward pass optimization (#5080) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2087 This PR introduces optimization for `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel with primary focus on `float` type and relatively small embedding dimensions. 2 things are implemented: 1) Extracted the common variables out of the loop to omit unnecessary synchronizations on memory load (compiler won't do that automatically) 2) Switch to 32 threads logical wave sizes to reduce granularity losses. Differential Revision: D86135611 Pulled By: q10 --- .../src/sparse_ops/sparse_group_index.cu | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index c1ac40dea6..96c57cde68 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,10 +12,18 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +#ifdef USE_ROCM +// The wave size is forced to be 32 on ROCm devices in favor +// of granularity losses reduction. +constexpr int EMULATED_WARP_SIZE = 32; +#else +constexpr int EMULATED_WARP_SIZE = kWarpSize; +#endif + // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; + GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = @@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t member_id, member_warp_id, num_cols, warps_per_row; - if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } @@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { // Compile time conditional - if (USE_INDEX_SELECT) { + if constexpr (USE_INDEX_SELECT) { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( @@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); + dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \