@@ -12,10 +12,19 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
+// The wave size is forced to be 32 on ROCm devices
+// to reduce granularity losses.
+
+#ifdef USE_ROCM
+constexpr int EMULATED_WARP_SIZE = 32;
+#else
+constexpr int EMULATED_WARP_SIZE = kWarpSize;
+#endif
+
 // TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
 
 // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
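For reference, a minimal standalone sketch (not part of this diff; UNROLL_FACTOR and log2_int are hypothetical stand-ins for the FBGEMM constants and helpers) of how the derived constants work out once the emulated wave size is pinned to 32, and why COLS_PER_WARP must stay a power of two so the ceil-division can be done with a shift:

#include <cassert>

constexpr int EMULATED_WARP_SIZE = 32;  // forced wave size on ROCm
constexpr int UNROLL_FACTOR = 1;        // stand-in for GROUP_INDEX_SELECT_UNROLL_FACTOR
constexpr int COLS_PER_WARP = UNROLL_FACTOR * EMULATED_WARP_SIZE;  // 32

// Illustrative log2 helper; not the helper FBGEMM actually uses.
constexpr int log2_int(int v) { return v <= 1 ? 0 : 1 + log2_int(v >> 1); }
constexpr int LOG_COLS_PER_WARP = log2_int(COLS_PER_WARP);  // 5

static_assert((COLS_PER_WARP & (COLS_PER_WARP - 1)) == 0,
              "COLS_PER_WARP must be a power of two");

int main() {
  // Ceil-division by COLS_PER_WARP expressed as a shift, as in the kernel:
  int num_cols = 100;
  int warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
  assert(warps_per_row == 4);  // ceil(100 / 32)
  return 0;
}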
@@ -43,12 +52,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t num_work_rows, // number of rows to work on per member
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
+  int32_t num_cols = 0;
+  int32_t warps_per_row = 0;
+
+  if constexpr (!USE_VAR_COLS) {
+    num_cols = num_cols_group[0];
+    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+  }
+
   for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
        warp_id < total_num_warps;
        warp_id += gridDim.x * blockDim.y) {
-    int32_t member_id, member_warp_id, num_cols, warps_per_row;
-    if (USE_VAR_COLS) {
-      __shared__ int member_ids[kMaxThreads / kWarpSize];
+    int32_t member_id = 0;
+    int32_t member_warp_id = 0;
+    if constexpr (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
       if (threadIdx.x == 0) {
         binary_search_range(
             &member_ids[threadIdx.y],
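The variable-column branch above maps a global warp id back to the group member that owns it by searching warp_offsets_group. As a rough host-side illustration only (this is not FBGEMM's binary_search_range, whose signature differs, and find_member is a made-up name), the lookup amounts to:

#include <algorithm>
#include <cstdint>
#include <vector>

// warp_offsets[m] is the first global warp id owned by member m,
// with warp_offsets[group_size] == total_num_warps.
int find_member(const std::vector<int64_t>& warp_offsets, int64_t warp_id) {
  // The last offset that is <= warp_id identifies the owning member.
  auto it = std::upper_bound(warp_offsets.begin(), warp_offsets.end(), warp_id);
  return static_cast<int>(it - warp_offsets.begin()) - 1;
}

// Example: three members owning 4, 2, and 6 warps respectively.
// warp_offsets = {0, 4, 6, 12}; warp_id 5 -> member 1,
// local warp id within the member = 5 - warp_offsets[1] = 1.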
@@ -63,8 +81,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       member_warp_id = warp_id - warp_offsets_group[member_id];
     } else {
       // All columns are the same
-      num_cols = num_cols_group[0];
-      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
       member_id = warp_id / (warps_per_row * num_work_rows);
       member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
     }
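In the uniform-column branch, hoisting num_cols and warps_per_row out of the loop works because every member owns exactly warps_per_row * num_work_rows consecutive warps, so the member id falls out of a plain integer division. A small illustrative sketch with made-up numbers:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_work_rows = 10;
  const int64_t warps_per_row = 4;  // e.g. ceil(100 cols / 32 cols per warp)
  const int64_t warps_per_member = warps_per_row * num_work_rows;  // 40

  const int64_t warp_id = 57;  // some global warp id
  const int64_t member_id = warp_id / warps_per_member;                   // 1
  const int64_t member_warp_id = warp_id - member_id * warps_per_member;  // 17

  assert(member_id == 1 && member_warp_id == 17);
  return 0;
}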
@@ -82,7 +98,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
 #pragma unroll
     for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
       // Compile time conditional
-      if (USE_INDEX_SELECT) {
+      if constexpr (USE_INDEX_SELECT) {
         output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
       } else {
         gpuAtomicAddNoReturn(
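The compile-time branch selects between a gather (the index_select forward, LDG path) and a scatter-add (the gpuAtomicAddNoReturn path, where repeated indices make atomics necessary). A CPU-side sketch of the two data movements, purely for illustration (function names and shapes are assumptions, not FBGEMM API):

#include <cstdint>
#include <vector>

// Forward: gather rows of input into output, one output row per index.
void index_select_rows(const std::vector<float>& input,     // [num_input_rows x num_cols]
                       const std::vector<int64_t>& indices,
                       std::vector<float>& output,          // [indices.size() x num_cols]
                       int64_t num_cols) {
  for (size_t row = 0; row < indices.size(); ++row) {
    const int64_t idx = indices[row];
    for (int64_t c = 0; c < num_cols; ++c) {
      output[row * num_cols + c] = input[idx * num_cols + c];
    }
  }
}

// Backward: scatter-add rows of grad_output into grad_input; on the GPU this
// is the atomic-add path since the same index may appear more than once.
void index_add_rows(const std::vector<float>& grad_output,  // [indices.size() x num_cols]
                    const std::vector<int64_t>& indices,
                    std::vector<float>& grad_input,         // [num_input_rows x num_cols]
                    int64_t num_cols) {
  for (size_t row = 0; row < indices.size(); ++row) {
    const int64_t idx = indices[row];
    for (int64_t c = 0; c < num_cols; ++c) {
      grad_input[idx * num_cols + c] += grad_output[row * num_cols + c];
    }
  }
}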
@@ -113,13 +129,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
   at::cuda::OptionalCUDAGuard device_guard(device);
 
   // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+  uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
       max_grid_size);
-  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
+  dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
 
 #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
   FBGEMM_LAUNCH_KERNEL(                                                  \
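For the launch configuration, a host-side sketch with made-up numbers (kMaxThreads, the SM count, and total_num_warps are illustrative, and cuda_calc_xblock_count is assumed to be the usual ceil-division helper):

#include <algorithm>
#include <cstdint>

int main() {
  const uint32_t kMaxThreads = 1024;
  const uint32_t emulated_warp_size = 32;  // forced wave size on ROCm
  const uint32_t sm_count = 108;           // example multiProcessorCount
  const uint32_t total_num_warps = 50000;

  const uint32_t num_warps_per_threadblock = kMaxThreads / emulated_warp_size;  // 32
  const uint32_t max_grid_size = sm_count * 8;                                  // 864

  // Assumed equivalent of cuda_calc_xblock_count: ceil-divide warps into blocks.
  const uint32_t needed_blocks =
      (total_num_warps + num_warps_per_threadblock - 1) / num_warps_per_threadblock;  // 1563
  const uint32_t grid_size = std::min(needed_blocks, max_grid_size);                  // 864

  // Block shape: x spans one emulated warp, y indexes the warp within the block,
  // mirroring dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1).
  (void)grid_size;
  return 0;
}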