Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

#ifdef USE_ROCM
// The wave size is forced to be 32 on ROCm devices in favor
// of granularity losses reduction.
constexpr int EMULATED_WARP_SIZE = 32;
#else
constexpr int EMULATED_WARP_SIZE = kWarpSize;
#endif

// TODO: Update UNROLL_FACTOR
constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;

// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
Expand Down Expand Up @@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
const int64_t num_work_rows, // number of rows to work on per member
const int64_t group_size) {
const auto total_num_warps = warp_offsets_group[group_size];
int32_t num_cols = 0;
int32_t warps_per_row = 0;

if constexpr (!USE_VAR_COLS) {
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
}

for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
int32_t member_id, member_warp_id, num_cols, warps_per_row;
if (USE_VAR_COLS) {
__shared__ int member_ids[kMaxThreads / kWarpSize];
int32_t member_id = 0;
int32_t member_warp_id = 0;
if constexpr (USE_VAR_COLS) {
__shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
if (threadIdx.x == 0) {
binary_search_range(
&member_ids[threadIdx.y],
Expand All @@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
member_warp_id = warp_id - warp_offsets_group[member_id];
} else {
// All columns are the same
num_cols = num_cols_group[0];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
}
Expand All @@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
// Compile time conditional
if (USE_INDEX_SELECT) {
if constexpr (USE_INDEX_SELECT) {
output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
} else {
gpuAtomicAddNoReturn(
Expand Down Expand Up @@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
at::cuda::OptionalCUDAGuard device_guard(device);

// Partition work based on num_work_rows
uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
max_grid_size);
dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);

#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
FBGEMM_LAUNCH_KERNEL( \
Expand Down
Loading