From d2efaa432b92a0de24b7ef22e5da87dd91a5352d Mon Sep 17 00:00:00 2001
From: jichen
Date: Wed, 15 Oct 2025 08:43:23 +0000
Subject: [PATCH 1/3] add subwarp opt for rocm warp64 on fwd v2 kernel

---
 .../forward/embedding_forward_split_kernel_v2_template.cu | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
index 42f499c6dd..34ce2c6f13 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
@@ -975,6 +975,13 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
     else if (tail_warp_size <= 16) {
       INVOKE_PROCESS_ALL_INDICES(large_Ls, 16, 0x55)
     }
+#if defined(USE_ROCM)
+    // It is unclear which step mask value to use when the group size is 32;
+    // when use_lxu_cache is false, the step mask has no effect.
+    else if (tail_warp_size <= 32 && !use_lxu_cache) {
+      INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)
+    }
+#endif
     else {
       INVOKE_PROCESS_ALL_INDICES(large_Ls, kWarpSize, 0xf)
     }

From bce492c4f9ae4cb0eeebf2119e30d4648e8342e9 Mon Sep 17 00:00:00 2001
From: jichen
Date: Thu, 21 Aug 2025 10:26:35 +0000
Subject: [PATCH 2/3] apply Vec4T on vbe forward

---
 ...embedding_forward_split_kernel_template.cu | 21 ++-----------------
 .../embedding_forward_split_template.cu       |  5 -----
 2 files changed, 2 insertions(+), 24 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
index a39d33e391..aada1cdad5 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu
@@ -84,11 +84,7 @@ using namespace fbgemm_gpu;

         {#-/* Set the weights row accessor */#}
-        {%- if is_rocm %}
-        const auto weights_row = rocm::WeightRowAccessorVec2
-        {%- else %}
         const auto weights_row = WeightRowAccessor
-        {%- endif %}
             <
                 {{ 'cache_t' if from_cache else 'emb_t' }},
                 cache_t
@@ -182,11 +178,7 @@ using namespace fbgemm_gpu;
         {%- endif %}

         {#-/* Set the weights row accessor */#}
-        {%- if is_rocm %}
-        const auto weights_row = rocm::WeightRowAccessorVec2
-        {%- else %}
         const auto weights_row = WeightRowAccessor
-        {%- endif %}
             <
                 {{ 'cache_t' if from_cache else 'emb_t' }},
                 cache_t
@@ -319,7 +311,7 @@ using namespace fbgemm_gpu;

     {%- if is_rocm %}
     {%- if not nobag %}
-    rocm::Vec2T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
+    Vec4T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
     {%- endif %}
     // Iterate over kThreadGroupSize indices
     for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength)
@@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel(
     #endif

     // Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes)
-    // for CUDA devices and 2 at a time for ROCm
-    {%- if is_rocm %}
-    constexpr int VEC_WIDTH = 2;
-    {%- else %}
     constexpr int VEC_WIDTH = 4;
-    {%- endif %}
     {%- if is_rocm %}
     // Unroll factor for ROCm devices
     constexpr int kManualUnrollLength = 4;
@@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel(
     const float inv_L = (mean_pooling && L != 0) ?
        static_cast<float>(1.0) / L: static_cast<float>(1.0);
     // Set up the accumulator buffer
-    {%- if is_rocm %}
-    rocm::Vec2T<cache_t> accumulators[kMaxVecsPerThread];
-    {%- else %}
     Vec4T<cache_t> accumulators[kMaxVecsPerThread];
-    {%- endif %}
     {%- endif %}

     {%- if dense %}
     {{ embedding_pool_or_store("NULL") }}
@@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel
 {%- endmacro %}

 {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
-    {%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %}
+    {%- set max_vecs_per_thread = kMaxVecsPerThread %}
     {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
     {%- for cache_type in ['float', 'at::Half'] %}
     {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
index 6574bda45e..a2956a8a3c 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
@@ -702,12 +702,7 @@ batch_index_select_dim0_codegen_forward_cuda(
     // kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But
     // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same
     // forward
-    {%- if is_rocm %}
-    // Account for Vec2 load for ROCm
-    constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread;
-    {%- else %}
     constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
-    {%- endif %}

     const auto grid = min(
         div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),

From f76eb7a8301212c0f22bd6aa1aed2622dd8bb98b Mon Sep 17 00:00:00 2001
From: kudomcho
Date: Thu, 28 Aug 2025 18:03:27 +0000
Subject: [PATCH 3/3] added rocm guard on wg size change on v2 fwd

---
 .../training/forward/embedding_forward_split_template.cu | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
index a2956a8a3c..f9fafd201c 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
@@ -776,9 +776,14 @@ batch_index_select_dim0_codegen_forward_cuda(
       // if (!is_experimental)
     } else {
       // Allocate num warps per table based on max_D
+
       const int num_warps_per_table = B * div_round_up(max_D, kWarpSize * 4);
-      const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize;
-
+      #ifdef USE_ROCM
+      const uint32_t num_warps_per_threadblock = kForwardMaxThreads / (kWarpSize * 2);
+      #else
+      const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize;
+      #endif
+
      const auto kernel_func =
          (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
               emb_t, cache_t, output_t, index_t, true>