Skip to content

Commit c55d800

Browse files
sryapfacebook-github-bot
authored and committed
Refactor jagged_index_select and add impl for CPU (#1586)
Summary: Pull Request resolved: #1586 Changes are as follows: - Update variable names of jagged_index_add_2d for GPU - Use a single autograd function to dispatch the op on CPU and GPU - Add an implementation on CPU using `at::parallel_for` to parallelize work on multiple threads. For jagged_index_add_2d (backward), locks are used instead of atomic add (one lock per row) to manage add conflicts between threads. Reviewed By: brad-mengchi Differential Revision: D43111264 fbshipit-source-id: e319a4cd3376dd3e4dd8eaf8d6ffef9f7a390961
1 parent a7ecc59 commit c55d800

File tree

7 files changed

+443
-116
lines changed

7 files changed

+443
-116
lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,5 +739,26 @@ std::vector<at::Tensor> group_index_add_cuda(
739739
const int num_output_rows,
740740
const int num_cols,
741741
const int num_groups);
742+
743+
std::vector<at::Tensor> jagged_index_select_2d(
744+
const at::Tensor& values,
745+
const at::Tensor& lengths,
746+
const at::Tensor& indices);
747+
748+
at::Tensor jagged_index_select_2d_forward_cpu(
749+
const at::Tensor& values,
750+
const at::Tensor& indices,
751+
const at::Tensor& input_offsets,
752+
const at::Tensor& output_offsets,
753+
const int64_t num_dense_output_rows);
754+
755+
at::Tensor jagged_index_add_2d_forward_cpu(
756+
const at::Tensor& grad,
757+
const at::Tensor& indices,
758+
const at::Tensor& grad_offsets,
759+
const at::Tensor& output_offsets,
760+
const int64_t num_dense_grad_rows,
761+
const int64_t num_output_rows);
762+
742763
#endif
743764
} // namespace fbgemm_gpu

fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,49 @@ struct StackArray {
330330
false, "unsupported number of jagged dim ", num_jagged_dim); \
331331
} \
332332
});
333+
334+
// TODO: Merge this with the device code
335+
/// Binary-searches `arr` for the first entry that is strictly greater than
/// `target` and writes its index to `*found`.
///
/// The search assumes `arr` is sorted in non-decreasing order (a requirement
/// of any binary search); duplicate entries are allowed and do not degrade
/// the running time, which is O(log num_entries).
///
/// @param found       Output: index of the first entry greater than `target`,
///                    or -1 when no such entry exists (including when
///                    `num_entries` is not positive)
/// @param arr         Pointer to the array to search
/// @param target      Value to compare the entries against
/// @param num_entries Number of entries in `arr`
template <typename scalar_t>
void binary_search_range_cpu(
    int* found,
    const scalar_t* arr,
    const scalar_t target,
    const int num_entries) {
  // Half-open search window [lo, hi). Invariant: every entry before `lo` is
  // <= target and every entry at or after `hi` is > target.
  int lo = 0;
  int hi = num_entries;
  while (lo < hi) {
    // Written this way (instead of (lo + hi) / 2) to avoid int overflow
    const int mid = lo + (hi - lo) / 2;
    if (target < arr[mid]) {
      hi = mid;
    } else {
      lo = mid + 1;
    }
  }
  // lo == num_entries means no entry is greater than target
  *found = (lo < num_entries) ? lo : -1;
}

fbgemm_gpu/src/jagged_tensor_ops.cu

Lines changed: 68 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,12 +2408,29 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel(
24082408
}
24092409
}
24102410
2411-
Tensor jagged_index_select_2d_cuda(
2411+
/// Copy sequences from input jagged tensor based on indices specified in the
2412+
/// indices tensor to an output jagged tensor (host function for dispatching
2413+
/// jagged_index_select_2d_kernel to GPU)
2414+
/// @param values 2D dense value tensor of input jagged tensor
2415+
/// @param indices 1D tensor that contains indices to be selected
2416+
/// from input jagged tensor
2417+
/// @param input_offsets 1D tensor that contains offsets of input
2418+
/// jagged tensor
2419+
/// @param output_offsets 1D tensor that contains offsets of output
2420+
/// jagged tensor
2421+
/// @param num_dense_output_rows The total number of rows in the 2D dense value
2422+
/// tensor of output jagged tensor
2423+
Tensor jagged_index_select_2d_forward_cuda(
24122424
const Tensor& values,
24132425
const Tensor& indices,
24142426
const Tensor& input_offsets,
24152427
const Tensor& output_offsets,
24162428
const int64_t num_dense_output_rows) {
2429+
TENSOR_ON_CUDA_GPU(values);
2430+
TENSOR_ON_CUDA_GPU(indices);
2431+
TENSOR_ON_CUDA_GPU(input_offsets);
2432+
TENSOR_ON_CUDA_GPU(output_offsets);
2433+
24172434
at::cuda::OptionalCUDAGuard device_guard;
24182435
device_guard.set_index(values.get_device());
24192436
@@ -2462,21 +2479,22 @@ Tensor jagged_index_select_2d_cuda(
24622479
template <typename index_t, typename offset_t, typename scalar_t>
24632480
__global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel(
24642481
scalar_t* output,
2465-
const scalar_t* grad,
2466-
const offset_t* grad_offsets,
2482+
const scalar_t* values,
2483+
const offset_t* input_offsets,
24672484
const index_t* indices,
24682485
const offset_t* output_offsets,
2469-
const int64_t num_grad_rows,
2470-
const int64_t num_dense_grad_rows,
2486+
const int64_t num_input_rows,
2487+
const int64_t num_dense_input_rows,
24712488
const int64_t num_cols) {
24722489
__shared__ int smem[1];
2473-
for (offset_t dense_grad_offset = blockIdx.x;
2474-
dense_grad_offset < num_dense_grad_rows;
2475-
dense_grad_offset += gridDim.x) {
2490+
for (offset_t dense_input_offset = blockIdx.x;
2491+
dense_input_offset < num_dense_input_rows;
2492+
dense_input_offset += gridDim.x) {
24762493
// Binary search
24772494
// TODO: use multiple threads to do bin search to reduce number of steps
24782495
if (threadIdx.x == 0) {
2479-
binary_search_range(smem, grad_offsets, dense_grad_offset, num_grad_rows);
2496+
binary_search_range(
2497+
smem, input_offsets, dense_input_offset, num_input_rows);
24802498
}
24812499
__syncthreads();
24822500
@@ -2486,48 +2504,66 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel(
24862504
24872505
// TODO: Can also be obtained during the binary search
24882506
// Relative index position
2489-
const offset_t rel_index =
2490-
dense_grad_offset - (index_pos == 0 ? 0 : grad_offsets[index_pos - 1]);
2507+
const offset_t rel_index = dense_input_offset -
2508+
(index_pos == 0 ? 0 : input_offsets[index_pos - 1]);
24912509
const index_t index = indices[index_pos];
24922510
const offset_t output_offset =
24932511
(index == 0 ? 0 : output_offsets[index - 1]) + rel_index;
24942512
24952513
// Shift buffers
2496-
const scalar_t* grad_ = grad + dense_grad_offset * num_cols;
2514+
const scalar_t* values_ = values + dense_input_offset * num_cols;
24972515
scalar_t* output_ = output + output_offset * num_cols;
24982516
24992517
// TODO: Avoid using atomicAdd (because it could lead to a numerical
25002518
// non-determinism issue)
25012519
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
2502-
gpuAtomicAdd(&output_[i], grad_[i]);
2520+
gpuAtomicAdd(&output_[i], values_[i]);
25032521
}
25042522
}
25052523
}
25062524
2507-
Tensor jagged_index_add_2d_cuda(
2508-
const Tensor& grad,
2525+
/// Add sequences from input jagged tensor to output jagged tensor based on
2526+
/// indices specified in the indices tensor (host function for dispatching
2527+
/// jagged_index_add_2d_kernel to GPU)
2528+
/// @param values 2D dense value tensor of input jagged tensor
2529+
/// @param indices 1D tensor that contains indices to be added in
2530+
/// output jagged tensor
2531+
/// @param input_offsets 1D tensor that contains offsets of input
2532+
/// jagged tensor
2533+
/// @param output_offsets 1D tensor that contains offsets of output
2534+
/// jagged tensor
2535+
/// @param num_dense_input_rows The total number of rows in the 2D dense value
2536+
/// tensor of input jagged tensor
2537+
/// @param num_output_rows The number of sequences in jagged output tensor
2538+
Tensor jagged_index_add_2d_forward_cuda(
2539+
const Tensor& values,
25092540
const Tensor& indices,
2510-
const Tensor& grad_offsets,
2541+
const Tensor& input_offsets,
25112542
const Tensor& output_offsets,
2512-
const int64_t num_dense_grad_rows,
2543+
const int64_t num_dense_input_rows,
25132544
const int64_t num_output_rows) {
2545+
TENSOR_ON_CUDA_GPU(values);
2546+
TENSOR_ON_CUDA_GPU(indices);
2547+
TENSOR_ON_CUDA_GPU(input_offsets);
2548+
TENSOR_ON_CUDA_GPU(output_offsets);
2549+
25142550
at::cuda::OptionalCUDAGuard device_guard;
2515-
device_guard.set_index(grad.get_device());
2551+
device_guard.set_index(values.get_device());
25162552
2517-
auto num_cols = grad.size(1);
2518-
const int64_t num_grad_rows = indices.numel();
2553+
auto num_cols = values.size(1);
2554+
const int64_t num_input_rows = indices.numel();
25192555
25202556
const int64_t max_num_blocks = 1024; // Arbitrarily set to this number for now
25212557
const int64_t max_num_threads = kMaxThreads;
2522-
const int64_t num_blocks = std::min(max_num_blocks, num_dense_grad_rows);
2558+
const int64_t num_blocks = std::min(max_num_blocks, num_dense_input_rows);
25232559
const int64_t num_threads = std::min(max_num_threads, num_cols);
2524-
Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());
2560+
Tensor output = at::zeros({num_output_rows, num_cols}, values.options());
25252561
25262562
if (num_blocks > 0) {
25272563
AT_DISPATCH_ALL_TYPES_AND2(
25282564
at::ScalarType::Half,
25292565
at::ScalarType::BFloat16,
2530-
grad.scalar_type(),
2566+
values.scalar_type(),
25312567
"jagged_index_add_2d_kernel_wrapper_1",
25322568
[&] {
25332569
AT_DISPATCH_INDEX_TYPES(
@@ -2540,12 +2576,12 @@ Tensor jagged_index_add_2d_cuda(
25402576
0,
25412577
at::cuda::getCurrentCUDAStream()>>>(
25422578
output.data_ptr<scalar_t>(),
2543-
grad.data_ptr<scalar_t>(),
2544-
grad_offsets.data_ptr<int64_t>(),
2579+
values.data_ptr<scalar_t>(),
2580+
input_offsets.data_ptr<int64_t>(),
25452581
indices.data_ptr<index_t>(),
25462582
output_offsets.data_ptr<int64_t>(),
2547-
num_grad_rows,
2548-
num_dense_grad_rows,
2583+
num_input_rows,
2584+
num_dense_input_rows,
25492585
num_cols);
25502586
C10_CUDA_KERNEL_LAUNCH_CHECK();
25512587
});
@@ -2555,84 +2591,6 @@ Tensor jagged_index_add_2d_cuda(
25552591
return output;
25562592
}
25572593
2558-
class JaggedIndexSelect2dGPUOp
2559-
: public torch::autograd::Function<JaggedIndexSelect2dGPUOp> {
2560-
public:
2561-
static torch::autograd::variable_list forward(
2562-
torch::autograd::AutogradContext* ctx,
2563-
const Tensor& values,
2564-
const Tensor& lengths,
2565-
const Tensor& indices) {
2566-
TENSOR_ON_CUDA_GPU(lengths);
2567-
TENSOR_ON_CUDA_GPU(values);
2568-
TENSOR_ON_CUDA_GPU(indices);
2569-
TENSORS_ON_SAME_DEVICE(lengths, indices);
2570-
TENSORS_ON_SAME_DEVICE(values, indices);
2571-
2572-
Tensor output_lengths = at::index_select(lengths, 0, indices);
2573-
Tensor output_offsets = output_lengths.cumsum(0);
2574-
Tensor input_offsets = lengths.cumsum(0);
2575-
2576-
// TODO: Try to not do D->H transfer
2577-
// The challenge here is num_dense_output_rows is needed for allocating the
2578-
// output buffer
2579-
int64_t num_dense_output_rows =
2580-
output_offsets[output_offsets.numel() - 1].item<int64_t>();
2581-
2582-
ctx->save_for_backward({indices, output_offsets, input_offsets});
2583-
ctx->saved_data["num_dense_grad_rows"] = num_dense_output_rows;
2584-
ctx->saved_data["num_input_rows"] = values.size(0);
2585-
2586-
return {
2587-
jagged_index_select_2d_cuda(
2588-
values,
2589-
indices,
2590-
input_offsets,
2591-
output_offsets,
2592-
num_dense_output_rows),
2593-
output_lengths};
2594-
}
2595-
2596-
static torch::autograd::variable_list backward(
2597-
torch::autograd::AutogradContext* ctx,
2598-
torch::autograd::variable_list grad_outputs) {
2599-
TORCH_CHECK(grad_outputs.size() == 2);
2600-
TENSOR_ON_CUDA_GPU(grad_outputs[0]);
2601-
2602-
const auto saved = ctx->get_saved_variables();
2603-
auto savedItr = std::begin(saved);
2604-
Tensor indices = *savedItr++;
2605-
Tensor grad_offsets = *savedItr++;
2606-
Tensor output_offsets = *savedItr++;
2607-
2608-
Tensor grad = grad_outputs[0];
2609-
TENSORS_ON_SAME_DEVICE(grad, indices);
2610-
2611-
int64_t num_dense_grad_rows =
2612-
ctx->saved_data["num_dense_grad_rows"].toInt();
2613-
int64_t num_output_rows = ctx->saved_data["num_input_rows"].toInt();
2614-
2615-
return {
2616-
jagged_index_add_2d_cuda(
2617-
grad,
2618-
indices,
2619-
grad_offsets,
2620-
output_offsets,
2621-
num_dense_grad_rows,
2622-
num_output_rows),
2623-
torch::autograd::Variable(), // lengths
2624-
torch::autograd::Variable() // indices
2625-
};
2626-
}
2627-
};
2628-
2629-
std::vector<Tensor> jagged_index_select_2d_gpu(
2630-
const Tensor& values,
2631-
const Tensor& lengths,
2632-
const Tensor& indices) {
2633-
return JaggedIndexSelect2dGPUOp::apply(values, lengths, indices);
2634-
}
2635-
26362594
class StackedJagged2DToDenseGPUOp
26372595
: public torch::autograd::Function<StackedJagged2DToDenseGPUOp> {
26382596
public:
@@ -3155,7 +3113,11 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
31553113
"batched_dense_vec_jagged_2d_mul_backward",
31563114
fbgemm_gpu::batched_dense_vec_jagged_2d_mul_backward);
31573115
DISPATCH_TO_CUDA(
3158-
"jagged_index_select", fbgemm_gpu::jagged_index_select_2d_gpu);
3116+
"jagged_index_select_2d_forward",
3117+
fbgemm_gpu::jagged_index_select_2d_forward_cuda);
3118+
DISPATCH_TO_CUDA(
3119+
"jagged_index_add_2d_forward",
3120+
fbgemm_gpu::jagged_index_add_2d_forward_cuda);
31593121
DISPATCH_TO_CUDA("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense);
31603122
DISPATCH_TO_CUDA("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
31613123
DISPATCH_TO_CUDA(

0 commit comments

Comments
 (0)