@@ -22,7 +22,7 @@ void compute_o_dot_do(
2222 const int bidh) {
2323 // The thread index.
2424 constexpr int kBlockM = T::kBlockM ;
25- // constexpr int kBlockN = T::kBlockN;
25+ constexpr int kBlockN = T::kBlockN ;
2626 constexpr int kHeadDim = T::kHeadDim ;
2727 constexpr int kNSGs = T::kNSGs ;
2828 constexpr int SubgroupSize = T::SubgroupSize;
@@ -31,8 +31,8 @@ void compute_o_dot_do(
3131
3232 auto sg = compat::get_nd_item<1 >().get_sub_group ();
3333 auto group = compat::get_nd_item<1 >().get_group ();
34- // auto first_thread_in_sg_idx = sg.get_group_linear_id() *
35- // trait.SubgroupSize;
34+ auto first_thread_in_sg_idx = sg.get_group_linear_id () * trait. SubgroupSize ;
35+
3636 auto bofst = Boffset (param);
3737
3838 const index_t o_offset = bofst.o_offset (bidb, bidh, m_block * kBlockM );
@@ -209,7 +209,7 @@ CUTLASS_DEVICE void apply_mask_causal(
209209 auto sg = compat::get_nd_item<1 >().get_sub_group ();
210210 auto group = compat::get_nd_item<1 >().get_group ();
211211 int sg_local_id = sg.get_local_id ();
212- // int sg_group_id = sg.get_group_id();
212+ int sg_group_id = sg.get_group_id ();
213213 Tensor rC_2d = make_tensor (rC.data (), convert_layout_2d_layout (rC.layout ()));
214214 CUTLASS_PRAGMA_UNROLL
215215 for (int n = 0 ; n < size<1 >(tensor); ++n) {
@@ -371,8 +371,8 @@ void dq_dk_dv_1colblock(
371371 constexpr int kBlockM = Trait::kBlockM ;
372372 constexpr int kBlockN = Trait::kBlockN ;
373373 constexpr bool is_causal = Trait::is_causal;
374- // constexpr int kNSGs = Trait::kNSGs;
375- // constexpr int SubgroupSize = Trait::SubgroupSize;
374+ constexpr int kNSGs = Trait::kNSGs ;
375+ constexpr int SubgroupSize = Trait::SubgroupSize;
376376 auto sg = compat::get_nd_item<1 >().get_sub_group ();
377377 auto group = compat::get_nd_item<1 >().get_group ();
378378 auto first_thread_in_sg_idx = sg.get_group_linear_id () * trait.SubgroupSize ;
@@ -675,7 +675,7 @@ void dq_dk_dv_1colblock(
675675 const int max_m_block = ceil_div (param.seq_len_q , kBlockM );
676676 const int tail_m = param.seq_len_q % kBlockM ;
677677
678- // cutlass::NumericConverter<T, float> converter;
678+ cutlass::NumericConverter<T, float > converter;
679679
680680 // clear accumulator
681681 clear (tdVrdV);
@@ -880,7 +880,7 @@ void convert_dq(
880880 int bidb,
881881 int bidh) {
882882 constexpr int kBlockM = T::kBlockM ;
883- // constexpr int kBlockN = T::kBlockN;
883+ constexpr int kBlockN = T::kBlockN ;
884884 constexpr int kHeadDim = T::kHeadDim ;
885885 using DType = typename T::DType;
886886 using VType = typename T::VType;
0 commit comments