@@ -2408,12 +2408,29 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel(
24082408 }
24092409}
24102410
2411- Tensor jagged_index_select_2d_cuda (
2411+ /// Copy sequences from input jagged tensor based on indices specified in the
2412+ /// indices tensor to an output jagged tensor (host function for dispatching
2413+ /// jagged_index_select_2d_kernel to GPU)
2414+ /// @param values 2D dense value tensor of input jagged tensor
2415+ /// @param indices 1D tensor that contains indices of the sequences to be
2416+ /// selected from input jagged tensor
2417+ /// @param input_offsets 1D tensor that contains offsets of input
2418+ /// jagged tensor
2419+ /// @param output_offsets 1D tensor that contains offsets of output
2420+ /// jagged tensor
2421+ /// @param num_dense_output_rows The total number of rows in the 2D dense value
2422+ /// tensor of output jagged tensor
2423+ Tensor jagged_index_select_2d_forward_cuda (
24122424 const Tensor& values,
24132425 const Tensor& indices,
24142426 const Tensor& input_offsets,
24152427 const Tensor& output_offsets,
24162428 const int64_t num_dense_output_rows) {
2429+ TENSOR_ON_CUDA_GPU (values);
2430+ TENSOR_ON_CUDA_GPU (indices);
2431+ TENSOR_ON_CUDA_GPU (input_offsets);
2432+ TENSOR_ON_CUDA_GPU (output_offsets);
2433+
24172434 at::cuda::OptionalCUDAGuard device_guard;
24182435 device_guard.set_index (values.get_device ());
24192436
@@ -2462,21 +2479,22 @@ Tensor jagged_index_select_2d_cuda(
24622479template <typename index_t , typename offset_t , typename scalar_t >
24632480__global__ __launch_bounds__ (kMaxThreads ) void jagged_index_add_2d_kernel(
24642481 scalar_t * output,
2465- const scalar_t * grad ,
2466- const offset_t * grad_offsets ,
2482+ const scalar_t * values ,
2483+ const offset_t * input_offsets ,
24672484 const index_t * indices,
24682485 const offset_t * output_offsets,
2469- const int64_t num_grad_rows ,
2470- const int64_t num_dense_grad_rows ,
2486+ const int64_t num_input_rows ,
2487+ const int64_t num_dense_input_rows ,
24712488 const int64_t num_cols) {
24722489 __shared__ int smem[1 ];
2473- for (offset_t dense_grad_offset = blockIdx .x ;
2474- dense_grad_offset < num_dense_grad_rows ;
2475- dense_grad_offset += gridDim .x ) {
2490+ for (offset_t dense_input_offset = blockIdx .x ;
2491+ dense_input_offset < num_dense_input_rows ;
2492+ dense_input_offset += gridDim .x ) {
24762493 // Binary search
24772494 // TODO: use multiple threads to do bin search to reduce number of steps
24782495 if (threadIdx .x == 0 ) {
2479- binary_search_range (smem, grad_offsets, dense_grad_offset, num_grad_rows);
2496+ binary_search_range (
2497+ smem, input_offsets, dense_input_offset, num_input_rows);
24802498 }
24812499 __syncthreads ();
24822500
@@ -2486,48 +2504,66 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel(
24862504
24872505 // TODO: Can also be obtained during the binary search
24882506 // Relative index position
2489- const offset_t rel_index =
2490- dense_grad_offset - (index_pos == 0 ? 0 : grad_offsets [index_pos - 1 ]);
2507+ const offset_t rel_index = dense_input_offset -
2508+ (index_pos == 0 ? 0 : input_offsets [index_pos - 1 ]);
24912509 const index_t index = indices[index_pos];
24922510 const offset_t output_offset =
24932511 (index == 0 ? 0 : output_offsets[index - 1 ]) + rel_index;
24942512
24952513 // Shift buffers
2496- const scalar_t * grad_ = grad + dense_grad_offset * num_cols;
2514+ const scalar_t * values_ = values + dense_input_offset * num_cols;
24972515 scalar_t * output_ = output + output_offset * num_cols;
24982516
24992517 // TODO: Avoid using atomicAdd (because it could lead to the numerical
25002518 // non-determinism issue)
25012519 for (int i = threadIdx .x ; i < num_cols; i += blockDim .x ) {
2502- gpuAtomicAdd (&output_[i], grad_ [i]);
2520+ gpuAtomicAdd (&output_[i], values_ [i]);
25032521 }
25042522 }
25052523}
25062524
2507- Tensor jagged_index_add_2d_cuda (
2508- const Tensor& grad,
2525+ /// Add sequences from input jagged tensor to output jagged tensor based on
2526+ /// indices specified in the indices tensor (host function for dispatching
2527+ /// jagged_index_add_2d_kernel to GPU)
2528+ /// @param values 2D dense value tensor of input jagged tensor
2529+ /// @param indices 1D tensor that contains indices of the sequences in
2530+ /// output jagged tensor that the input sequences are added to
2531+ /// @param input_offsets 1D tensor that contains offsets of input
2532+ /// jagged tensor
2533+ /// @param output_offsets 1D tensor that contains offsets of output
2534+ /// jagged tensor
2535+ /// @param num_dense_input_rows The total number of rows in the 2D dense value
2536+ /// tensor of input jagged tensor
2537+ /// @param num_output_rows The number of rows in the 2D dense value tensor of output jagged tensor
2538+ Tensor jagged_index_add_2d_forward_cuda (
2539+ const Tensor& values,
25092540 const Tensor& indices,
2510- const Tensor& grad_offsets ,
2541+ const Tensor& input_offsets ,
25112542 const Tensor& output_offsets,
2512- const int64_t num_dense_grad_rows ,
2543+ const int64_t num_dense_input_rows ,
25132544 const int64_t num_output_rows) {
2545+ TENSOR_ON_CUDA_GPU (values);
2546+ TENSOR_ON_CUDA_GPU (indices);
2547+ TENSOR_ON_CUDA_GPU (input_offsets);
2548+ TENSOR_ON_CUDA_GPU (output_offsets);
2549+
25142550 at::cuda::OptionalCUDAGuard device_guard;
2515- device_guard.set_index (grad .get_device ());
2551+ device_guard.set_index (values .get_device ());
25162552
2517- auto num_cols = grad .size (1 );
2518- const int64_t num_grad_rows = indices.numel ();
2553+ auto num_cols = values .size (1 );
2554+ const int64_t num_input_rows = indices.numel ();
25192555
25202556 const int64_t max_num_blocks = 1024 ; // Arbitrarily set to this number for now
25212557 const int64_t max_num_threads = kMaxThreads ;
2522- const int64_t num_blocks = std::min (max_num_blocks, num_dense_grad_rows );
2558+ const int64_t num_blocks = std::min (max_num_blocks, num_dense_input_rows );
25232559 const int64_t num_threads = std::min (max_num_threads, num_cols);
2524- Tensor output = at::zeros ({num_output_rows, num_cols}, grad .options ());
2560+ Tensor output = at::zeros ({num_output_rows, num_cols}, values .options ());
25252561
25262562 if (num_blocks > 0 ) {
25272563 AT_DISPATCH_ALL_TYPES_AND2 (
25282564 at::ScalarType::Half,
25292565 at::ScalarType::BFloat16,
2530- grad .scalar_type (),
2566+ values .scalar_type (),
25312567 " jagged_index_add_2d_kernel_wrapper_1" ,
25322568 [&] {
25332569 AT_DISPATCH_INDEX_TYPES (
@@ -2540,12 +2576,12 @@ Tensor jagged_index_add_2d_cuda(
25402576 0,
25412577 at::cuda::getCurrentCUDAStream()>>>(
25422578 output.data_ptr<scalar_t >(),
2543- grad .data_ptr<scalar_t>(),
2544- grad_offsets .data_ptr<int64_t>(),
2579+ values .data_ptr<scalar_t>(),
2580+ input_offsets .data_ptr<int64_t>(),
25452581 indices.data_ptr<index_t>(),
25462582 output_offsets.data_ptr<int64_t>(),
2547- num_grad_rows ,
2548- num_dense_grad_rows ,
2583+ num_input_rows ,
2584+ num_dense_input_rows ,
25492585 num_cols);
25502586 C10_CUDA_KERNEL_LAUNCH_CHECK ();
25512587 });
@@ -2555,84 +2591,6 @@ Tensor jagged_index_add_2d_cuda(
25552591 return output;
25562592}
25572593
2558- class JaggedIndexSelect2dGPUOp
2559- : public torch::autograd::Function<JaggedIndexSelect2dGPUOp> {
2560- public:
2561- static torch::autograd::variable_list forward (
2562- torch::autograd::AutogradContext* ctx,
2563- const Tensor& values,
2564- const Tensor& lengths,
2565- const Tensor& indices) {
2566- TENSOR_ON_CUDA_GPU (lengths);
2567- TENSOR_ON_CUDA_GPU (values);
2568- TENSOR_ON_CUDA_GPU (indices);
2569- TENSORS_ON_SAME_DEVICE (lengths, indices);
2570- TENSORS_ON_SAME_DEVICE (values, indices);
2571-
2572- Tensor output_lengths = at::index_select (lengths, 0 , indices);
2573- Tensor output_offsets = output_lengths.cumsum (0 );
2574- Tensor input_offsets = lengths.cumsum (0 );
2575-
2576- // TODO: Try to not do D->H transfer
2577- // The challenge here is num_dense_output_rows is needed for allocating the
2578- // output buffer
2579- int64_t num_dense_output_rows =
2580- output_offsets[output_offsets.numel () - 1 ].item <int64_t >();
2581-
2582- ctx->save_for_backward ({indices, output_offsets, input_offsets});
2583- ctx->saved_data [" num_dense_grad_rows" ] = num_dense_output_rows;
2584- ctx->saved_data [" num_input_rows" ] = values.size (0 );
2585-
2586- return {
2587- jagged_index_select_2d_cuda (
2588- values,
2589- indices,
2590- input_offsets,
2591- output_offsets,
2592- num_dense_output_rows),
2593- output_lengths};
2594- }
2595-
2596- static torch::autograd::variable_list backward (
2597- torch::autograd::AutogradContext* ctx,
2598- torch::autograd::variable_list grad_outputs) {
2599- TORCH_CHECK (grad_outputs.size () == 2 );
2600- TENSOR_ON_CUDA_GPU (grad_outputs[0 ]);
2601-
2602- const auto saved = ctx->get_saved_variables ();
2603- auto savedItr = std::begin (saved);
2604- Tensor indices = *savedItr++;
2605- Tensor grad_offsets = *savedItr++;
2606- Tensor output_offsets = *savedItr++;
2607-
2608- Tensor grad = grad_outputs[0 ];
2609- TENSORS_ON_SAME_DEVICE (grad, indices);
2610-
2611- int64_t num_dense_grad_rows =
2612- ctx->saved_data [" num_dense_grad_rows" ].toInt ();
2613- int64_t num_output_rows = ctx->saved_data [" num_input_rows" ].toInt ();
2614-
2615- return {
2616- jagged_index_add_2d_cuda (
2617- grad,
2618- indices,
2619- grad_offsets,
2620- output_offsets,
2621- num_dense_grad_rows,
2622- num_output_rows),
2623- torch::autograd::Variable (), // lengths
2624- torch::autograd::Variable () // indices
2625- };
2626- }
2627- };
2628-
2629- std::vector<Tensor> jagged_index_select_2d_gpu (
2630- const Tensor& values,
2631- const Tensor& lengths,
2632- const Tensor& indices) {
2633- return JaggedIndexSelect2dGPUOp::apply (values, lengths, indices);
2634- }
2635-
26362594class StackedJagged2DToDenseGPUOp
26372595 : public torch::autograd::Function<StackedJagged2DToDenseGPUOp> {
26382596 public:
@@ -3155,7 +3113,11 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
31553113 " batched_dense_vec_jagged_2d_mul_backward" ,
31563114 fbgemm_gpu::batched_dense_vec_jagged_2d_mul_backward);
31573115 DISPATCH_TO_CUDA (
3158- " jagged_index_select" , fbgemm_gpu::jagged_index_select_2d_gpu);
3116+ " jagged_index_select_2d_forward" ,
3117+ fbgemm_gpu::jagged_index_select_2d_forward_cuda);
3118+ DISPATCH_TO_CUDA (
3119+ " jagged_index_add_2d_forward" ,
3120+ fbgemm_gpu::jagged_index_add_2d_forward_cuda);
31593121 DISPATCH_TO_CUDA (" jagged_1d_to_dense" , fbgemm_gpu::jagged_1d_to_dense);
31603122 DISPATCH_TO_CUDA (" jagged_2d_to_dense" , fbgemm_gpu::jagged_2d_to_dense);
31613123 DISPATCH_TO_CUDA (
0 commit comments