
Commit bf48eea

reduce compilation size and time (#886)
Summary: Pull Request resolved: #886

This patch refactors FBGEMM to slightly reduce the compilation size and time associated with `cub`.

1. Moved the inline function `asynchronous_complete_cumsum()` from `embedding_backward_template_helpers.cuh` to `split_embeddings_utils.cu`, the only file that uses it.
2. Instead of calling the template function `cub::DeviceRadixSort::SortPairs` directly, call a non-template FBGEMM wrapper (`radix_sort_pairs`), so the template is no longer instantiated in every generated `gen_embedding_backward_*` file.

Reviewed By: jspark1105

Differential Revision: D33801456

fbshipit-source-id: 92ebb3369c4fea25d7360bbacaf36476cba54136
Parent commit: 2711eca
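The second change is essentially a template firewall: the heavy `cub` headers and the `SortPairs` template are compiled in exactly one translation unit, while every generated file sees only an ordinary function declaration. A minimal sketch of the pattern, assuming a single key/value combination (the real commit declares several overloads via macros, shown in the diffs below; the file names here are illustrative):

```cpp
// sort_wrapper.cuh (illustrative name): what callers include. No cub headers needed.
#include <cuda_runtime.h>
#include <cstdint>

cudaError_t radix_sort_pairs(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    const int64_t* d_keys_in,
    int64_t* d_keys_out,
    const int32_t* d_values_in,
    int32_t* d_values_out,
    int num_items);

// sort_wrapper.cu (illustrative name): the only file that includes cub.
#include <cub/device/device_radix_sort.cuh>

cudaError_t radix_sort_pairs(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    const int64_t* d_keys_in,
    int64_t* d_keys_out,
    const int32_t* d_values_in,
    int32_t* d_values_out,
    int num_items) {
  // SortPairs is instantiated once here and shared by every caller at link time.
  return cub::DeviceRadixSort::SortPairs(
      d_temp_storage,
      temp_storage_bytes,
      d_keys_in,
      d_keys_out,
      d_values_in,
      d_values_out,
      num_items);
}
```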

4 files changed: 105 additions, 50 deletions

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 4 additions & 5 deletions
@@ -797,7 +797,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
     auto lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations);
     if (lxu_cache_locations.size(0) > 0) {
       size_t temp_storage_bytes = 0;
-      AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
+      AT_CUDA_CHECK(radix_sort_pairs(
           nullptr,
           temp_storage_bytes,
           linear_indices.data_ptr<int64_t>(),
@@ -812,7 +812,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
-      AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
+      AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
           linear_indices.data_ptr<int64_t>(),
@@ -838,12 +838,11 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
       {% endif %}
       "split_embedding_backward_{{ optimizer }}_exact_kernel",
       ([&] {
-
         {% if weighted %}
         auto indice_weights_sorted = at::empty_like(indice_weights);
         {
           size_t temp_storage_bytes = 0;
-          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
+          AT_CUDA_CHECK(radix_sort_pairs(
              nullptr,
              temp_storage_bytes,
              linear_indices.data_ptr<int64_t>(),
@@ -863,7 +862,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
          auto temp_storage = at::empty(
              {static_cast<int64_t>(temp_storage_bytes)},
              indices.options().dtype(at::kByte));
-          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
+          AT_CUDA_CHECK(radix_sort_pairs(
             temp_storage.data_ptr(),
             temp_storage_bytes,
             linear_indices.data_ptr<int64_t>(),
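As the hunks above show, the call sites keep cub's two-phase convention: the first `radix_sort_pairs` call passes a null temp-storage pointer and only reports the required scratch size; the second call performs the actual sort. A standalone sketch of that pattern under the new wrapper (the `sort_by_key` helper and raw-pointer buffers are illustrative, not part of this commit):

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include "fbgemm_gpu/split_embeddings_utils.cuh"  // declares radix_sort_pairs

// Illustrative helper: sort (int64_t key, int32_t value) pairs already on the GPU.
cudaError_t sort_by_key(
    const int64_t* d_keys_in, int64_t* d_keys_out,
    const int32_t* d_values_in, int32_t* d_values_out,
    int num_items, cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Phase 1: null temp-storage pointer, so only the scratch size is computed.
  cudaError_t err = radix_sort_pairs(
      nullptr, temp_storage_bytes,
      d_keys_in, d_keys_out, d_values_in, d_values_out,
      num_items, 0, 64 /* full 64-bit keys */, stream);
  if (err != cudaSuccess) return err;

  void* d_temp_storage = nullptr;
  err = cudaMalloc(&d_temp_storage, temp_storage_bytes);
  if (err != cudaSuccess) return err;

  // Phase 2: same arguments, now with real scratch space, performs the sort.
  err = radix_sort_pairs(
      d_temp_storage, temp_storage_bytes,
      d_keys_in, d_keys_out, d_values_in, d_values_out,
      num_items, 0, 64, stream);
  cudaFree(d_temp_storage);
  return err;
}
```

The generated backward code above does the same thing, except that the scratch buffer is an `at::Tensor` of bytes rather than a raw `cudaMalloc` allocation.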

fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh

Lines changed: 0 additions & 45 deletions
@@ -5,14 +5,6 @@
  * LICENSE file in the root directory of this source tree.
  */

-// clang-format off
-#include "fbgemm_gpu/cub_namespace_prefix.cuh"
-#include <cub/device/device_radix_sort.cuh>
-#include <cub/device/device_run_length_encode.cuh>
-#include <cub/device/device_scan.cuh>
-#include "fbgemm_gpu/cub_namespace_postfix.cuh"
-// clang-format on
-
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/TensorUtils.h>
@@ -33,43 +25,6 @@
 #include "fbgemm_cuda_utils.cuh"
 #include "sparse_ops_utils.h"

-inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
-  at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(t_in.get_device());
-  size_t temp_storage_bytes = 0;
-  TORCH_CHECK(t_in.is_contiguous());
-  TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
-  // CUB only handles up to INT_MAX elements.
-  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
-  TORCH_CHECK(t_in.dim() == 1);
-  auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
-  t_out[0].zero_();
-  AT_DISPATCH_INTEGRAL_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper1", ([&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            nullptr,
-            temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      }));
-  auto temp_storage = at::empty(
-      {static_cast<int64_t>(temp_storage_bytes)},
-      t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper2", ([&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            temp_storage.data_ptr(),
-            temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      }));
-  return t_out;
-}
-
 class FixedDivisor {
  public:
   explicit FixedDivisor(const int32_t d) : d_(d) {
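The function removed here is not deleted from the project; it moves verbatim into `fbgemm_gpu/src/split_embeddings_utils.cu` (the last file in this commit). Its contract, as the code above shows: given a 1-D integer tensor of N values, it returns N+1 values, starting with a zero followed by the running totals, so the last element is the grand total. A hedged usage sketch with illustrative values:

```cpp
#include <ATen/ATen.h>

// Same signature as the function in the diff; shown here only to illustrate behavior.
at::Tensor asynchronous_complete_cumsum(at::Tensor t_in);

void offsets_from_lengths_example() {
  // Three segments of length 1 each, as a CUDA long tensor.
  auto lengths = at::ones(
      {3}, at::TensorOptions().dtype(at::kLong).device(at::kCUDA));
  auto offsets = asynchronous_complete_cumsum(lengths);
  // offsets now holds {0, 1, 2, 3}: one extra element, with the total at the end.
}
```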

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh

Lines changed: 24 additions & 0 deletions
@@ -132,3 +132,27 @@ transpose_embedding_input(
     at::Tensor indices,
     at::Tensor offsets,
     bool nobag = false);
+
+// Use these functions instead of directly calling cub functions
+// to reduce code size and compilation time.
+// Arguments are the same as cub::DeviceRadixSort::SortPairs
+#define DECL_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \
+  cudaError_t radix_sort_pairs(                \
+      void* d_temp_storage,                    \
+      size_t& temp_storage_bytes,              \
+      const KeyT* d_keys_in,                   \
+      KeyT* d_keys_out,                        \
+      const ValueT* d_values_in,               \
+      ValueT* d_values_out,                    \
+      int num_items,                           \
+      int begin_bit = 0,                       \
+      int end_bit = sizeof(KeyT) * 8,          \
+      cudaStream_t stream = 0,                 \
+      bool debug_synchronous = false)
+
+DECL_RADIX_SORT_PAIRS_FN(int64_t, float);
+DECL_RADIX_SORT_PAIRS_FN(int64_t, double);
+DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
+DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
+
+#undef DECL_RADIX_SORT_PAIRS_FN
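Spelled out, `DECL_RADIX_SORT_PAIRS_FN(int64_t, float)` declares the following overload (the other three lines declare the analogous `double`, `int64_t`, and `int32_t` value-type overloads); the default arguments mirror `cub::DeviceRadixSort::SortPairs`:

```cpp
cudaError_t radix_sort_pairs(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    const int64_t* d_keys_in,
    int64_t* d_keys_out,
    const float* d_values_in,
    float* d_values_out,
    int num_items,
    int begin_bit = 0,
    int end_bit = sizeof(int64_t) * 8,
    cudaStream_t stream = 0,
    bool debug_synchronous = false);
```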

fbgemm_gpu/src/split_embeddings_utils.cu

Lines changed: 77 additions & 0 deletions
@@ -10,6 +10,51 @@
 #include <c10/cuda/CUDAStream.h>
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"

+// clang-format off
+#include "fbgemm_gpu/cub_namespace_prefix.cuh"
+#include <cub/device/device_radix_sort.cuh>
+#include <cub/device/device_run_length_encode.cuh>
+#include <cub/device/device_scan.cuh>
+#include "fbgemm_gpu/cub_namespace_postfix.cuh"
+// clang-format on
+
+inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(t_in.get_device());
+  size_t temp_storage_bytes = 0;
+  TORCH_CHECK(t_in.is_contiguous());
+  TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
+  // CUB only handles up to INT_MAX elements.
+  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
+  TORCH_CHECK(t_in.dim() == 1);
+  auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
+  t_out[0].zero_();
+  AT_DISPATCH_INTEGRAL_TYPES(
+      t_in.scalar_type(), "cub_inclusive_sum_wrapper1", ([&] {
+        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+            nullptr,
+            temp_storage_bytes,
+            t_in.data_ptr<scalar_t>(),
+            t_out.data_ptr<scalar_t>() + 1,
+            t_in.numel(),
+            at::cuda::getCurrentCUDAStream()));
+      }));
+  auto temp_storage = at::empty(
+      {static_cast<int64_t>(temp_storage_bytes)},
+      t_in.options().dtype(at::kByte));
+  AT_DISPATCH_INTEGRAL_TYPES(
+      t_in.scalar_type(), "cub_inclusive_sum_wrapper2", ([&] {
+        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+            temp_storage.data_ptr(),
+            temp_storage_bytes,
+            t_in.data_ptr<scalar_t>(),
+            t_out.data_ptr<scalar_t>() + 1,
+            t_in.numel(),
+            at::cuda::getCurrentCUDAStream()));
+      }));
+  return t_out;
+}
+
 using Tensor = at::Tensor;

 using namespace fbgemm_gpu;
@@ -227,3 +272,35 @@ transpose_embedding_input(
       sorted_linear_indices_num_runs,
       sorted_linear_indices_cumulative_run_lengths};
 }
+
+#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT)                        \
+  cudaError_t radix_sort_pairs(                                      \
+      void* d_temp_storage,                                          \
+      size_t& temp_storage_bytes,                                    \
+      const KeyT* d_keys_in,                                         \
+      KeyT* d_keys_out,                                              \
+      const ValueT* d_values_in,                                     \
+      ValueT* d_values_out,                                          \
+      int num_items,                                                 \
+      int begin_bit,                                                 \
+      int end_bit,                                                   \
+      cudaStream_t stream,                                           \
+      bool debug_synchronous) {                                      \
+    return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
+        d_temp_storage,                                              \
+        temp_storage_bytes,                                          \
+        d_keys_in,                                                   \
+        d_keys_out,                                                  \
+        d_values_in,                                                 \
+        d_values_out,                                                \
+        num_items,                                                   \
+        begin_bit,                                                   \
+        end_bit,                                                     \
+        stream,                                                      \
+        debug_synchronous);                                          \
+  }
+
+DEF_RADIX_SORT_PAIRS_FN(int64_t, float);
+DEF_RADIX_SORT_PAIRS_FN(int64_t, double);
+DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
+DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
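Each `DEF_RADIX_SORT_PAIRS_FN(...)` line, in turn, expands to a plain forwarding definition; for the `(int64_t, int32_t)` pair it is equivalent to the function below, so `cub::DeviceRadixSort::SortPairs` is instantiated only in this translation unit rather than in every generated backward file:

```cpp
cudaError_t radix_sort_pairs(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    const int64_t* d_keys_in,
    int64_t* d_keys_out,
    const int32_t* d_values_in,
    int32_t* d_values_out,
    int num_items,
    int begin_bit,
    int end_bit,
    cudaStream_t stream,
    bool debug_synchronous) {
  // Forward directly to cub; callers never need the cub headers.
  return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
      d_temp_storage,
      temp_storage_bytes,
      d_keys_in,
      d_keys_out,
      d_values_in,
      d_values_out,
      num_items,
      begin_bit,
      end_bit,
      stream,
      debug_synchronous);
}
```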
