@@ -12,10 +12,18 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
+#ifdef USE_ROCM
+// The wave size is forced to be 32 on ROCm devices to reduce
+// granularity losses.
+constexpr int EMULATED_WARP_SIZE = 32;
+#else
+constexpr int EMULATED_WARP_SIZE = kWarpSize;
+#endif
+
 // TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
 
 // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t num_work_rows, // number of rows to work on per member
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
+  int32_t num_cols = 0;
+  int32_t warps_per_row = 0;
+
+  if constexpr (!USE_VAR_COLS) {
+    num_cols = num_cols_group[0];
+    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+  }
+
   for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
        warp_id < total_num_warps;
        warp_id += gridDim.x * blockDim.y) {
-    int32_t member_id, member_warp_id, num_cols, warps_per_row;
-    if (USE_VAR_COLS) {
-      __shared__ int member_ids[kMaxThreads / kWarpSize];
+    int32_t member_id = 0;
+    int32_t member_warp_id = 0;
+    if constexpr (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
       if (threadIdx.x == 0) {
         binary_search_range(
             &member_ids[threadIdx.y],
@@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       member_warp_id = warp_id - warp_offsets_group[member_id];
     } else {
       // All columns are the same
-      num_cols = num_cols_group[0];
-      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
       member_id = warp_id / (warps_per_row * num_work_rows);
       member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
     }
@@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
 #pragma unroll
       for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
         // Compile time conditional
-        if (USE_INDEX_SELECT) {
+        if constexpr (USE_INDEX_SELECT) {
           output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
         } else {
           gpuAtomicAddNoReturn(
@@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
   at::cuda::OptionalCUDAGuard device_guard(device);
 
   // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+  uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
       max_grid_size);
-  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
+  dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
 
 #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
   FBGEMM_LAUNCH_KERNEL( \
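
For reference, a minimal host-side sketch of how the launch geometry follows from the emulated warp size. The values here are stand-ins for illustration (kMaxThreads is assumed to be 1024; the real constants come from the fbgemm_gpu headers), not a copy of the patch:

#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical stand-ins; the real values come from fbgemm_gpu headers.
constexpr int kMaxThreads = 1024;
constexpr int EMULATED_WARP_SIZE = 32;  // forced on ROCm, kWarpSize elsewhere

dim3 make_block_size() {
  // 1024 / 32 = 32 warps per thread block, each EMULATED_WARP_SIZE threads wide.
  const uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
  return dim3(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
}

On ROCm the hardware wave can be wider than 32, so running the block as 32-wide emulated warps keeps the per-warp column tile small, which appears to be what the comment about reducing granularity losses refers to.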
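Likewise, a standalone sketch of the round-up division the kernel uses for warps_per_row; COLS_PER_WARP must be a power of two for the shift to match the division, and the concrete numbers below are hypothetical:

#include <cstdio>

int main() {
  // Hypothetical values for illustration; the real constants live in fbgemm_gpu.
  constexpr int COLS_PER_WARP = 32;     // must be a power of two
  constexpr int LOG_COLS_PER_WARP = 5;  // log2(COLS_PER_WARP)
  const int num_cols = 100;             // hypothetical row width
  // Round-up division via shift, as in the kernel's warps_per_row computation.
  const int warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
  std::printf("warps_per_row = %d\n", warps_per_row);  // prints 4
  return 0;
}

Hoisting this computation out of the warp loop (the new if constexpr (!USE_VAR_COLS) block) works because all group members share the same column count in that case, so num_cols and warps_per_row are loop-invariant.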