@@ -133,6 +133,69 @@ constexpr auto calc_io_size(){
#endif
}

+ #ifndef USE_ROCM
+ // To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
+ // into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
+ // used on sm_90 and sm_100 exclusively.
+ template <int vec_size, typename func_t, typename array_t>
+ C10_LAUNCH_BOUNDS_1(num_threads())
+ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+   if constexpr (vec_size == 8) {
+ #if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+     using traits = function_traits<func_t>;
+     constexpr auto io_size = calc_io_size<func_t>();
+     int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+     if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                      // just do a naive unrolled loop
+       auto input_calc = TrivialOffsetCalculator<traits::arity>();
+       auto output_calc = TrivialOffsetCalculator<1>();
+       auto loader = memory::LoadWithoutCast();
+       auto storer = memory::StoreWithoutCast();
+       auto policy = memory::policies::unroll<
+           array_t,
+           decltype(input_calc),
+           decltype(output_calc),
+           memory::LoadWithoutCast,
+           memory::StoreWithoutCast,
+           elems_per_thread<io_size>()>(
+           data, remaining, input_calc, output_calc, loader, storer);
+       elementwise_kernel_helper(f, policy);
+     } else { // if this block has a full `block_work_size` data to handle, use
+              // vectorized memory access
+       elementwise_kernel_helper(
+           f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+     }
+ #endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+   } else {
+     using traits = function_traits<func_t>;
+     constexpr auto io_size = calc_io_size<func_t>();
+     int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+     if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                      // just do a naive unrolled loop
+       auto input_calc = TrivialOffsetCalculator<traits::arity>();
+       auto output_calc = TrivialOffsetCalculator<1>();
+       auto loader = memory::LoadWithoutCast();
+       auto storer = memory::StoreWithoutCast();
+       auto policy = memory::policies::unroll<
+           array_t,
+           decltype(input_calc),
+           decltype(output_calc),
+           memory::LoadWithoutCast,
+           memory::StoreWithoutCast,
+           elems_per_thread<io_size>()>(
+           data, remaining, input_calc, output_calc, loader, storer);
+       elementwise_kernel_helper(f, policy);
+     } else { // if this block has a full `block_work_size` data to handle, use
+              // vectorized memory access
+       elementwise_kernel_helper(
+           f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+     }
+   }
+ }
+
+ #else // USE_ROCM
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -157,15 +220,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
    elementwise_kernel_helper(f, policy);
  } else { // if this block has a full `block_work_size` data to handle, use
           // vectorized memory access
- #ifdef USE_ROCM
    constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
- #else
-    constexpr auto optimal_vec_size = vec_size;
- #endif
    elementwise_kernel_helper(
        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
  }
}
+ #endif // USE_ROCM

template <
    typename func_t,
@@ -212,6 +272,11 @@ static inline void launch_vectorized_kernel(
  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
  // that causes some numerical mismatches with uint8 on sm80 and sm90.
  // TODO: Revisit this after CUDA 12.8 update.
+  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
+  const int computeCapability = p->major * 10 + p->minor;
+  if (computeCapability != 90 && computeCapability != 100) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
  if constexpr (sizeof(cpp_type) < 2) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
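
For reference, the host-side gate added in the last hunk can be read in isolation as the sketch below. It is illustrative only: the helper name pick_vec_size is hypothetical and it queries the plain CUDA runtime instead of at::cuda::getDeviceProperties, but it mirrors the same logic: clamp the vectorization width to 4 unless the device reports compute capability 9.0 or 10.0, where the vec_size == 8 kernel instantiation is compiled.

// Illustrative sketch, not the PyTorch implementation.
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

uint16_t pick_vec_size(uint16_t requested, int device_index) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_index);
  const int cc = prop.major * 10 + prop.minor;  // e.g. 90 for sm_90, 100 for sm_100
  if (cc != 90 && cc != 100) {
    // Only sm_90 / sm_100 builds get the vec8 kernel, so clamp everywhere else.
    requested = std::min<uint16_t>(requested, 4);
  }
  return requested;
}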