Skip to content

Commit 26f4937

Browse files
Aidyn-A and pytorchbot
authored and committed
[ATen][CUDA] Optimize 128 bit vectorization (#148320)
Fixes #147376. As per request: #145746 (review) This PR omits sm80 or older of using vec8 kernels due to long compilation and large binary size. Pull Request resolved: #148320 Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman (cherry picked from commit 72337bd)
1 parent cd6037e commit 26f4937

File tree

1 file changed

+69
-4
lines changed

1 file changed

+69
-4
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,69 @@ constexpr auto calc_io_size(){
133133
#endif
134134
}
135135

136+
#ifndef USE_ROCM
// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
// used on sm_90 and sm_100 exclusively.
//
// Elementwise kernel using vectorized global-memory access.
//   N    — total number of elements to process
//   f    — the elementwise functor applied per element
//   data — array of input/output pointers (output first, per the unroll/vectorized policies)
// The vec_size==8 path is compiled only for __CUDA_ARCH__ 900/1000 (sm_90 / sm_100);
// on other archs that branch compiles to an empty body, which is fine because the host
// launcher never selects vec_size=8 there (it clamps vec_size to <=4 — see
// launch_vectorized_kernel below). NOTE(review): this invariant lives in the launcher,
// not in this kernel — confirm callers cannot reach vec_size==8 on other archs.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
    using traits = function_traits<func_t>;
    constexpr auto io_size = calc_io_size<func_t>();
    // Number of elements left for this block; blocks before the tail all get a
    // full io_block_work_size<io_size>() chunk.
    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;

    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
                                                     // just do a naive unrolled loop
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      auto policy = memory::policies::unroll<
          array_t,
          decltype(input_calc),
          decltype(output_calc),
          memory::LoadWithoutCast,
          memory::StoreWithoutCast,
          elems_per_thread<io_size>()>(
          data, remaining, input_calc, output_calc, loader, storer);
      elementwise_kernel_helper(f, policy);
    } else { // if this block has a full `block_work_size` data to handle, use
             // vectorized memory access
      elementwise_kernel_helper(
          f,
          memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(
              data));
    }
#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
  } else {
    // vec_size in [2, 4]: identical structure to the vec8 branch above, duplicated
    // on purpose so that the vec8 instantiation can be arch-guarded independently.
    using traits = function_traits<func_t>;
    constexpr auto io_size = calc_io_size<func_t>();
    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;

    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
                                                     // just do a naive unrolled loop
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      auto policy = memory::policies::unroll<
          array_t,
          decltype(input_calc),
          decltype(output_calc),
          memory::LoadWithoutCast,
          memory::StoreWithoutCast,
          elems_per_thread<io_size>()>(
          data, remaining, input_calc, output_calc, loader, storer);
      elementwise_kernel_helper(f, policy);
    } else { // if this block has a full `block_work_size` data to handle, use
             // vectorized memory access
      elementwise_kernel_helper(
          f,
          memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(
              data));
    }
  }
}
197+
198+
#else // USE_ROCM
136199
template <int vec_size, typename func_t, typename array_t>
137200
C10_LAUNCH_BOUNDS_1(num_threads())
138201
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -157,15 +220,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
157220
elementwise_kernel_helper(f, policy);
158221
} else { // if this block has a full `block_work_size` data to handle, use
159222
// vectorized memory access
160-
#ifdef USE_ROCM
161223
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
162-
#else
163-
constexpr auto optimal_vec_size = vec_size;
164-
#endif
165224
elementwise_kernel_helper(
166225
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
167226
}
168227
}
228+
#endif // USE_ROCM
169229

170230
template <
171231
typename func_t,
@@ -212,6 +272,11 @@ static inline void launch_vectorized_kernel(
212272
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
213273
// that causes some numerical mismatches with uint8 on sm80 and sm90.
214274
// TODO: Revisit this after CUDA 12.8 update.
275+
cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
276+
const int computeCapability = p->major * 10 + p->minor;
277+
if (computeCapability != 90 && computeCapability != 100) {
278+
vec_size = std::min<uint16_t>(vec_size, 4);
279+
}
215280
if constexpr (sizeof(cpp_type) < 2) {
216281
vec_size = std::min<uint16_t>(vec_size, 4);
217282
}

0 commit comments

Comments
 (0)