@@ -39,23 +39,25 @@ def chunk_kda_fwd_kernel_intra_sub_inter(
3939 K : tl .constexpr ,
4040 BT : tl .constexpr ,
4141 BC : tl .constexpr ,
42+ BC2 : tl .constexpr ,
4243 BK : tl .constexpr ,
4344 NC : tl .constexpr ,
4445 IS_VARLEN : tl .constexpr ,
4546):
46- i_t , i_c , i_bh = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
47+ i_i , i_t , i_bh = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
4748 i_b , i_h = i_bh // H , i_bh % H
48- i_i , i_j = i_c // NC , i_c % NC
4949 if IS_VARLEN :
5050 i_n , i_t = tl .load (chunk_indices + i_t * 2 ).to (tl .int32 ), tl .load (chunk_indices + i_t * 2 + 1 ).to (tl .int32 )
5151 bos , eos = tl .load (cu_seqlens + i_n ).to (tl .int32 ), tl .load (cu_seqlens + i_n + 1 ).to (tl .int32 )
5252 T = eos - bos
5353 else :
5454 bos , eos = i_b * T , i_b * T + T
5555
56- if i_t * BT + i_i * BC >= T :
57- return
58- if i_i <= i_j :
56+ tl .static_assert (NC <= 4 , "This kernel is specialized for NC <= 4" )
57+
58+ i_ti = i_t * BT + i_i * BC
59+ i_tn = i_ti + BC2
60+ if i_ti >= T :
5961 return
6062
6163 q += (bos * H + i_h ) * K
@@ -64,40 +66,95 @@ def chunk_kda_fwd_kernel_intra_sub_inter(
6466 Aqk += (bos * H + i_h ) * BT
6567 Akk += (bos * H + i_h ) * BT
6668
67- p_b = tl .make_block_ptr (beta + bos * H + i_h , (T ,), (H ,), (i_t * BT + i_i * BC ,), (BC ,), (0 ,))
69+ p_b = tl .make_block_ptr (beta + bos * H + i_h , (T ,), (H ,), (i_ti ,), (BC ,), (0 ,))
6870 b_b = tl .load (p_b , boundary_check = (0 ,))
6971
7072 b_Aqk = tl .zeros ([BC , BC ], dtype = tl .float32 )
7173 b_Akk = tl .zeros ([BC , BC ], dtype = tl .float32 )
74+
75+ b_Aqk0 = tl .zeros ([BC , BC ], dtype = tl .float32 )
76+ b_Akk0 = tl .zeros ([BC , BC ], dtype = tl .float32 )
77+ b_Aqk1 = tl .zeros ([BC , BC ], dtype = tl .float32 )
78+ b_Akk1 = tl .zeros ([BC , BC ], dtype = tl .float32 )
79+ b_Aqk2 = tl .zeros ([BC , BC ], dtype = tl .float32 )
80+ b_Akk2 = tl .zeros ([BC , BC ], dtype = tl .float32 )
7281 for i_k in range (tl .cdiv (K , BK )):
73- p_q = tl .make_block_ptr (q , (T , K ), (H * K , 1 ), (i_t * BT + i_i * BC , i_k * BK ), (BC , BK ), (1 , 0 ))
82+ p_q = tl .make_block_ptr (q , (T , K ), (H * K , 1 ), (i_ti , i_k * BK ), (BC , BK ), (1 , 0 ))
83+ p_k = tl .make_block_ptr (k , (T , K ), (H * K , 1 ), (i_ti , i_k * BK ), (BC , BK ), (1 , 0 ))
84+ p_g = tl .make_block_ptr (g , (T , K ), (H * K , 1 ), (i_ti , i_k * BK ), (BC , BK ), (1 , 0 ))
7485 o_k = i_k * BK + tl .arange (0 , BK )
7586 m_k = o_k < K
76- # [BK,]
77- b_gn = tl .load (g + (i_t * BT + i_i * BC ) * H * K + o_k , mask = m_k , other = 0 )
87+
7888 # [BC, BK]
79- p_g = tl .make_block_ptr (g , (T , K ), (H * K , 1 ), (i_t * BT + i_i * BC , i_k * BK ), (BC , BK ), (1 , 0 ))
80- p_k = tl .make_block_ptr (k , (T , K ), (H * K , 1 ), (i_t * BT + i_i * BC , i_k * BK ), (BC , BK ), (1 , 0 ))
81- b_kt = tl .make_block_ptr (k , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + i_j * BC ), (BK , BC ), (0 , 1 ))
82- p_gk = tl .make_block_ptr (g , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + i_j * BC ), (BK , BC ), (0 , 1 ))
89+ b_q = tl .load (p_q , boundary_check = (0 , 1 ))
90+ b_k = tl .load (p_k , boundary_check = (0 , 1 ))
8391 b_g = tl .load (p_g , boundary_check = (0 , 1 ))
84- b_k = tl .load (p_k , boundary_check = (0 , 1 )) * exp (b_g - b_gn [None , :])
85- b_gk = tl .load (p_gk , boundary_check = (0 , 1 ))
86- b_kt = tl .load (b_kt , boundary_check = (0 , 1 ))
87- # [BC, BC]
88- b_ktg = b_kt * exp (b_gn [:, None ] - b_gk )
89- b_Akk += tl .dot (b_k , b_ktg )
92+ # [BK,]
93+ b_gn = tl .load (g + i_ti * H * K + o_k , mask = m_k , other = 0 )
94+ # [BC, BK]
95+ b_gqk = exp (b_g - b_gn [None , :])
96+ b_qg = b_q * b_gqk
97+ b_kg = b_k * b_gqk
98+ if i_i > 0 :
99+ p_kt0 = tl .make_block_ptr (k , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT ), (BK , BC ), (0 , 1 ))
100+ p_gk0 = tl .make_block_ptr (g , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT ), (BK , BC ), (0 , 1 ))
101+ b_kt0 = tl .load (p_kt0 , boundary_check = (0 , 1 ))
102+ b_gk0 = tl .load (p_gk0 , boundary_check = (0 , 1 ))
103+ b_ktg0 = b_kt0 * exp (b_gn [:, None ] - b_gk0 )
104+ b_Aqk0 += tl .dot (b_qg , b_ktg0 )
105+ b_Akk0 += tl .dot (b_kg , b_ktg0 )
106+ if i_i > 1 :
107+ p_kt1 = tl .make_block_ptr (k , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + BC ), (BK , BC ), (0 , 1 ))
108+ p_gk1 = tl .make_block_ptr (g , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + BC ), (BK , BC ), (0 , 1 ))
109+ b_gk1 = tl .load (p_gk1 , boundary_check = (0 , 1 ))
110+ b_kt1 = tl .load (p_kt1 , boundary_check = (0 , 1 ))
111+ b_ktg1 = b_kt1 * exp (b_gn [:, None ] - b_gk1 )
112+ b_Aqk1 += tl .dot (b_qg , b_ktg1 )
113+ b_Akk1 += tl .dot (b_kg , b_ktg1 )
114+ if i_i > 2 :
115+ p_kt2 = tl .make_block_ptr (k , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + 2 * BC ), (BK , BC ), (0 , 1 ))
116+ p_gk2 = tl .make_block_ptr (g , (K , T ), (1 , H * K ), (i_k * BK , i_t * BT + 2 * BC ), (BK , BC ), (0 , 1 ))
117+ b_gk2 = tl .load (p_gk2 , boundary_check = (0 , 1 ))
118+ b_kt2 = tl .load (p_kt2 , boundary_check = (0 , 1 ))
119+ b_ktg2 = b_kt2 * exp (b_gn [:, None ] - b_gk2 )
120+ b_Aqk2 += tl .dot (b_qg , b_ktg2 )
121+ b_Akk2 += tl .dot (b_kg , b_ktg2 )
122+
123+ if i_tn < T :
124+ b_gn2 = tl .load (g + i_tn * H * K + o_k , mask = m_k , other = 0 )
125+ b_gqk2 = exp (b_g - b_gn2 [None , :])
126+ b_ktg = tl .trans (b_k * exp (b_gn2 [None , :] - b_g ))
127+ b_Aqk += tl .dot (b_q * b_gqk2 , b_ktg )
128+ b_Akk += tl .dot (b_k * b_gqk2 , b_ktg )
90129
91- b_q = tl .load (p_q , boundary_check = (0 , 1 ))
92- b_qg = b_q * exp (b_g - b_gn [None , :]) * scale
93- b_Aqk += tl .dot (b_qg , b_ktg )
130+ if i_i > 0 :
131+ p_Aqk0 = tl .make_block_ptr (Aqk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , 0 ), (BC , BC ), (1 , 0 ))
132+ p_Akk0 = tl .make_block_ptr (Akk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , 0 ), (BC , BC ), (1 , 0 ))
133+ tl .store (p_Aqk0 , (b_Aqk0 * scale ).to (Aqk .dtype .element_ty ), boundary_check = (0 , 1 ))
134+ tl .store (p_Akk0 , (b_Akk0 * b_b [:, None ]).to (Akk .dtype .element_ty ), boundary_check = (0 , 1 ))
135+ if i_i > 1 :
136+ p_Aqk1 = tl .make_block_ptr (Aqk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , BC ), (BC , BC ), (1 , 0 ))
137+ p_Akk1 = tl .make_block_ptr (Akk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , BC ), (BC , BC ), (1 , 0 ))
138+ tl .store (p_Aqk1 , (b_Aqk1 * scale ).to (Aqk .dtype .element_ty ), boundary_check = (0 , 1 ))
139+ tl .store (p_Akk1 , (b_Akk1 * b_b [:, None ]).to (Akk .dtype .element_ty ), boundary_check = (0 , 1 ))
140+ if i_i > 2 :
141+ p_Aqk2 = tl .make_block_ptr (Aqk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , 2 * BC ), (BC , BC ), (1 , 0 ))
142+ p_Akk2 = tl .make_block_ptr (Akk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , 2 * BC ), (BC , BC ), (1 , 0 ))
143+ tl .store (p_Aqk2 , (b_Aqk2 * scale ).to (Aqk .dtype .element_ty ), boundary_check = (0 , 1 ))
144+ tl .store (p_Akk2 , (b_Akk2 * b_b [:, None ]).to (Akk .dtype .element_ty ), boundary_check = (0 , 1 ))
145+
146+ if i_tn >= T :
147+ return
148+ o_i = tl .arange (0 , BC )
149+ m_A = (o_i >= BC2 )[:, None ] & (o_i < BC2 )
94150
95- b_Akk *= b_b [:, None ]
151+ b_Aqk = tl .where (m_A , b_Aqk * scale , 0. )
152+ b_Akk = tl .where (m_A , b_Akk * b_b [:, None ], 0. )
96153
97- p_Akk = tl .make_block_ptr (Akk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , i_j * BC ), (BC , BC ), (1 , 0 ))
98- tl .store (p_Akk , b_Akk .to (Akk .dtype .element_ty ), boundary_check = (0 , 1 ))
99- p_Aqk = tl .make_block_ptr (Aqk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , i_j * BC ), (BC , BC ), (1 , 0 ))
154+ p_Aqk = tl .make_block_ptr (Aqk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , i_i * BC ), (BC , BC ), (1 , 0 ))
155+ p_Akk = tl .make_block_ptr (Akk , (T , BT ), (H * BT , 1 ), (i_t * BT + i_i * BC , i_i * BC ), (BC , BC ), (1 , 0 ))
100156 tl .store (p_Aqk , b_Aqk .to (Aqk .dtype .element_ty ), boundary_check = (0 , 1 ))
157+ tl .store (p_Akk , b_Akk .to (Akk .dtype .element_ty ), boundary_check = (0 , 1 ))
101158
102159
103160@triton .heuristics ({
@@ -438,14 +495,14 @@ def chunk_kda_fwd_intra(
438495 chunk_indices = prepare_chunk_indices (cu_seqlens , BT )
439496 NT = triton .cdiv (T , BT ) if cu_seqlens is None else len (chunk_indices )
440497
441- BC = min ( 16 , BT )
498+ BC , BC2 = 16 , 8
442499 NC = triton .cdiv (BT , BC )
443500 BK = max (triton .next_power_of_2 (K ), 16 )
444501
445502 Aqk = torch .empty (B , T , H , BT , device = k .device , dtype = output_dtype )
446503 Akk = torch .empty (B , T , H , BT , device = k .device , dtype = output_dtype )
447- grid = (NT , NC * NC , B * H )
448504
505+ grid = (NC , NT , B * H )
449506 chunk_kda_fwd_kernel_intra_sub_inter [grid ](
450507 q = q ,
451508 k = k ,
@@ -461,6 +518,7 @@ def chunk_kda_fwd_intra(
461518 K = K ,
462519 BT = BT ,
463520 BC = BC ,
521+ BC2 = BC2 ,
464522 NC = NC ,
465523 )
466524
@@ -476,6 +534,7 @@ def chunk_kda_fwd_intra(
476534 scale = scale ,
477535 cu_seqlens = cu_seqlens ,
478536 chunk_size = BT ,
537+ sub_chunk_size = BC2 ,
479538 )
480539 else :
481540 # Original sub-chunk based implementation
@@ -494,7 +553,7 @@ def chunk_kda_fwd_intra(
494553 H = H ,
495554 K = K ,
496555 BT = BT ,
497- BC = BC ,
556+ BC = BC2 ,
498557 BK = BK ,
499558 )
500559
@@ -519,8 +578,8 @@ def chunk_kda_bwd_intra(
519578 db : torch .Tensor ,
520579 dg : torch .Tensor ,
521580 cu_seqlens : torch .LongTensor | None = None ,
522- chunk_size : int = 64 ,
523581 chunk_indices : torch .LongTensor | None = None ,
582+ chunk_size : int = 64 ,
524583):
525584 B , T , H , K = k .shape
526585 BT = chunk_size
@@ -530,7 +589,6 @@ def chunk_kda_bwd_intra(
530589 if chunk_indices is None and cu_seqlens is not None :
531590 chunk_indices = prepare_chunk_indices (cu_seqlens , BT )
532591 NT = triton .cdiv (T , BT ) if cu_seqlens is None else len (chunk_indices )
533- # NC = 4
534592 NC = triton .cdiv (BT , BC )
535593 NK = triton .cdiv (K , BK )
536594
0 commit comments