Skip to content

Commit d4b33a3

Browse files
authored
[KDA] Faster inter computation in 64x64 intra fwd (#658)
* Update intra inter impls * Support passing headdim * Minor fix * Fix potential OOD
1 parent 0a87b0f commit d4b33a3

File tree

3 files changed

+112
-52
lines changed

3 files changed

+112
-52
lines changed

benchmarks/benchmark_training_throughput.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def profile(
5656
context_len: int = 2048,
5757
varlen: bool = False,
5858
num_heads: int | None = None,
59+
head_dim: int | None = None,
5960
num_hidden_layers: int | None = None,
6061
warmup_steps: int = 16,
6162
steps: int = 32,
@@ -71,6 +72,9 @@ def profile(
7172
config = configs[name] if name in configs else AutoConfig.from_pretrained(name)
7273
if num_heads is not None:
7374
config.num_heads = num_heads
75+
if head_dim is not None:
76+
config.head_dim = head_dim
77+
config.hidden_size = config.num_heads * config.head_dim
7478
if num_hidden_layers is not None:
7579
config.num_hidden_layers = num_hidden_layers
7680
model = AutoModelForCausalLM.from_config(config).cuda().to(dtype)
@@ -147,6 +151,7 @@ def profile(
147151
parser.add_argument("--context_len", default=None, type=int)
148152
parser.add_argument("--varlen", action='store_true')
149153
parser.add_argument("--num_heads", default=None, type=int)
154+
parser.add_argument("--head_dim", default=None, type=int)
150155
parser.add_argument("--num_hidden_layers", default=None, type=int)
151156
parser.add_argument("--warmup_steps", default=64, type=int)
152157
parser.add_argument("--steps", default=256, type=int)
@@ -159,6 +164,7 @@ def profile(
159164
context_len=args.context_len,
160165
varlen=args.varlen,
161166
num_heads=args.num_heads,
167+
head_dim=args.head_dim,
162168
num_hidden_layers=args.num_hidden_layers,
163169
warmup_steps=args.warmup_steps,
164170
steps=args.steps,

fla/ops/kda/chunk_intra.py

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,25 @@ def chunk_kda_fwd_kernel_intra_sub_inter(
3939
K: tl.constexpr,
4040
BT: tl.constexpr,
4141
BC: tl.constexpr,
42+
BC2: tl.constexpr,
4243
BK: tl.constexpr,
4344
NC: tl.constexpr,
4445
IS_VARLEN: tl.constexpr,
4546
):
46-
i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
47+
i_i, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
4748
i_b, i_h = i_bh // H, i_bh % H
48-
i_i, i_j = i_c // NC, i_c % NC
4949
if IS_VARLEN:
5050
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
5151
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
5252
T = eos - bos
5353
else:
5454
bos, eos = i_b * T, i_b * T + T
5555

56-
if i_t * BT + i_i * BC >= T:
57-
return
58-
if i_i <= i_j:
56+
tl.static_assert(NC <= 4, "This kernel is specialized for NC <= 4")
57+
58+
i_ti = i_t * BT + i_i * BC
59+
i_tn = i_ti + BC2
60+
if i_ti >= T:
5961
return
6062

6163
q += (bos * H + i_h) * K
@@ -64,40 +66,95 @@ def chunk_kda_fwd_kernel_intra_sub_inter(
6466
Aqk += (bos * H + i_h) * BT
6567
Akk += (bos * H + i_h) * BT
6668

67-
p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
69+
p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_ti,), (BC,), (0,))
6870
b_b = tl.load(p_b, boundary_check=(0,))
6971

7072
b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
7173
b_Akk = tl.zeros([BC, BC], dtype=tl.float32)
74+
75+
b_Aqk0 = tl.zeros([BC, BC], dtype=tl.float32)
76+
b_Akk0 = tl.zeros([BC, BC], dtype=tl.float32)
77+
b_Aqk1 = tl.zeros([BC, BC], dtype=tl.float32)
78+
b_Akk1 = tl.zeros([BC, BC], dtype=tl.float32)
79+
b_Aqk2 = tl.zeros([BC, BC], dtype=tl.float32)
80+
b_Akk2 = tl.zeros([BC, BC], dtype=tl.float32)
7281
for i_k in range(tl.cdiv(K, BK)):
73-
p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
82+
p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0))
83+
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0))
84+
p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0))
7485
o_k = i_k * BK + tl.arange(0, BK)
7586
m_k = o_k < K
76-
# [BK,]
77-
b_gn = tl.load(g + (i_t * BT + i_i * BC) * H*K + o_k, mask=m_k, other=0)
87+
7888
# [BC, BK]
79-
p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80-
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
81-
b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82-
p_gk = tl.make_block_ptr(g, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
89+
b_q = tl.load(p_q, boundary_check=(0, 1))
90+
b_k = tl.load(p_k, boundary_check=(0, 1))
8391
b_g = tl.load(p_g, boundary_check=(0, 1))
84-
b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
85-
b_gk = tl.load(p_gk, boundary_check=(0, 1))
86-
b_kt = tl.load(b_kt, boundary_check=(0, 1))
87-
# [BC, BC]
88-
b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
89-
b_Akk += tl.dot(b_k, b_ktg)
92+
# [BK,]
93+
b_gn = tl.load(g + i_ti * H*K + o_k, mask=m_k, other=0)
94+
# [BC, BK]
95+
b_gqk = exp(b_g - b_gn[None, :])
96+
b_qg = b_q * b_gqk
97+
b_kg = b_k * b_gqk
98+
if i_i > 0:
99+
p_kt0 = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
100+
p_gk0 = tl.make_block_ptr(g, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
101+
b_kt0 = tl.load(p_kt0, boundary_check=(0, 1))
102+
b_gk0 = tl.load(p_gk0, boundary_check=(0, 1))
103+
b_ktg0 = b_kt0 * exp(b_gn[:, None] - b_gk0)
104+
b_Aqk0 += tl.dot(b_qg, b_ktg0)
105+
b_Akk0 += tl.dot(b_kg, b_ktg0)
106+
if i_i > 1:
107+
p_kt1 = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
108+
p_gk1 = tl.make_block_ptr(g, (K, T), (1, H*K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
109+
b_gk1 = tl.load(p_gk1, boundary_check=(0, 1))
110+
b_kt1 = tl.load(p_kt1, boundary_check=(0, 1))
111+
b_ktg1 = b_kt1 * exp(b_gn[:, None] - b_gk1)
112+
b_Aqk1 += tl.dot(b_qg, b_ktg1)
113+
b_Akk1 += tl.dot(b_kg, b_ktg1)
114+
if i_i > 2:
115+
p_kt2 = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + 2 * BC), (BK, BC), (0, 1))
116+
p_gk2 = tl.make_block_ptr(g, (K, T), (1, H*K), (i_k * BK, i_t * BT + 2 * BC), (BK, BC), (0, 1))
117+
b_gk2 = tl.load(p_gk2, boundary_check=(0, 1))
118+
b_kt2 = tl.load(p_kt2, boundary_check=(0, 1))
119+
b_ktg2 = b_kt2 * exp(b_gn[:, None] - b_gk2)
120+
b_Aqk2 += tl.dot(b_qg, b_ktg2)
121+
b_Akk2 += tl.dot(b_kg, b_ktg2)
122+
123+
if i_tn < T:
124+
b_gn2 = tl.load(g + i_tn * H*K + o_k, mask=m_k, other=0)
125+
b_gqk2 = exp(b_g - b_gn2[None, :])
126+
b_ktg = tl.trans(b_k * exp(b_gn2[None, :] - b_g))
127+
b_Aqk += tl.dot(b_q * b_gqk2, b_ktg)
128+
b_Akk += tl.dot(b_k * b_gqk2, b_ktg)
90129

91-
b_q = tl.load(p_q, boundary_check=(0, 1))
92-
b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
93-
b_Aqk += tl.dot(b_qg, b_ktg)
130+
if i_i > 0:
131+
p_Aqk0 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, 0), (BC, BC), (1, 0))
132+
p_Akk0 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, 0), (BC, BC), (1, 0))
133+
tl.store(p_Aqk0, (b_Aqk0 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1))
134+
tl.store(p_Akk0, (b_Akk0 * b_b[:, None]).to(Akk.dtype.element_ty), boundary_check=(0, 1))
135+
if i_i > 1:
136+
p_Aqk1 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, BC), (BC, BC), (1, 0))
137+
p_Akk1 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, BC), (BC, BC), (1, 0))
138+
tl.store(p_Aqk1, (b_Aqk1 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1))
139+
tl.store(p_Akk1, (b_Akk1 * b_b[:, None]).to(Akk.dtype.element_ty), boundary_check=(0, 1))
140+
if i_i > 2:
141+
p_Aqk2 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, 2 * BC), (BC, BC), (1, 0))
142+
p_Akk2 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, 2 * BC), (BC, BC), (1, 0))
143+
tl.store(p_Aqk2, (b_Aqk2 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1))
144+
tl.store(p_Akk2, (b_Akk2 * b_b[:, None]).to(Akk.dtype.element_ty), boundary_check=(0, 1))
145+
146+
if i_tn >= T:
147+
return
148+
o_i = tl.arange(0, BC)
149+
m_A = (o_i >= BC2)[:, None] & (o_i < BC2)
94150

95-
b_Akk *= b_b[:, None]
151+
b_Aqk = tl.where(m_A, b_Aqk * scale, 0.)
152+
b_Akk = tl.where(m_A, b_Akk * b_b[:, None], 0.)
96153

97-
p_Akk = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
98-
tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1))
99-
p_Aqk = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
154+
p_Aqk = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_i * BC), (BC, BC), (1, 0))
155+
p_Akk = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_i * BC), (BC, BC), (1, 0))
100156
tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
157+
tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1))
101158

102159

103160
@triton.heuristics({
@@ -438,14 +495,14 @@ def chunk_kda_fwd_intra(
438495
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
439496
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
440497

441-
BC = min(16, BT)
498+
BC, BC2 = 16, 8
442499
NC = triton.cdiv(BT, BC)
443500
BK = max(triton.next_power_of_2(K), 16)
444501

445502
Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
446503
Akk = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
447-
grid = (NT, NC * NC, B * H)
448504

505+
grid = (NC, NT, B * H)
449506
chunk_kda_fwd_kernel_intra_sub_inter[grid](
450507
q=q,
451508
k=k,
@@ -461,6 +518,7 @@ def chunk_kda_fwd_intra(
461518
K=K,
462519
BT=BT,
463520
BC=BC,
521+
BC2=BC2,
464522
NC=NC,
465523
)
466524

@@ -476,6 +534,7 @@ def chunk_kda_fwd_intra(
476534
scale=scale,
477535
cu_seqlens=cu_seqlens,
478536
chunk_size=BT,
537+
sub_chunk_size=BC2,
479538
)
480539
else:
481540
# Original sub-chunk based implementation
@@ -494,7 +553,7 @@ def chunk_kda_fwd_intra(
494553
H=H,
495554
K=K,
496555
BT=BT,
497-
BC=BC,
556+
BC=BC2,
498557
BK=BK,
499558
)
500559

@@ -519,8 +578,8 @@ def chunk_kda_bwd_intra(
519578
db: torch.Tensor,
520579
dg: torch.Tensor,
521580
cu_seqlens: torch.LongTensor | None = None,
522-
chunk_size: int = 64,
523581
chunk_indices: torch.LongTensor | None = None,
582+
chunk_size: int = 64,
524583
):
525584
B, T, H, K = k.shape
526585
BT = chunk_size
@@ -530,7 +589,6 @@ def chunk_kda_bwd_intra(
530589
if chunk_indices is None and cu_seqlens is not None:
531590
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
532591
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
533-
# NC = 4
534592
NC = triton.cdiv(BT, BC)
535593
NK = triton.cdiv(K, BK)
536594

fla/ops/kda/chunk_intra_token_parallel.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def chunk_kda_fwd_kernel_intra_token_parallel(
3636
H: tl.constexpr,
3737
K: tl.constexpr,
3838
BT: tl.constexpr,
39+
BC: tl.constexpr,
3940
BH: tl.constexpr,
4041
USE_EXP2: tl.constexpr,
4142
IS_VARLEN: tl.constexpr,
@@ -71,7 +72,7 @@ def chunk_kda_fwd_kernel_intra_token_parallel(
7172
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
7273
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
7374
i_t = i_tg - bos
74-
T = eos - bos # Current sequence length
75+
T = eos - bos # Current sequence length
7576

7677
# Safety check
7778
if i_t >= T or i_tg >= eos:
@@ -85,8 +86,6 @@ def chunk_kda_fwd_kernel_intra_token_parallel(
8586
if i_t >= T:
8687
return
8788

88-
# Find which sub-chunk (BC=16) this token belongs to
89-
BC: tl.constexpr = 16
9089
i_chunk = i_t // BT # which BT=64 chunk
9190
i_subchunk = (i_t % BT) // BC # which BC=16 sub-chunk within the BT chunk
9291

@@ -103,18 +102,15 @@ def chunk_kda_fwd_kernel_intra_token_parallel(
103102

104103
# Load q[i_t, h:h+BH, :] - shape [BH, K]
105104
# For varlen, we use global offset: bos + i_t = i_tg
106-
p_q = tl.make_block_ptr(q + (bos + i_t) * H * K, (H, K), (K, 1),
107-
(i_h_start, 0), (BH, BK), (0, 1))
105+
p_q = tl.make_block_ptr(q + (bos + i_t) * H * K, (H, K), (K, 1), (i_h_start, 0), (BH, BK), (0, 1))
108106
b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
109107

110108
# Load g[i_t, h:h+BH, :]
111-
p_g = tl.make_block_ptr(g + (bos + i_t) * H * K, (H, K), (K, 1),
112-
(i_h_start, 0), (BH, BK), (0, 1))
109+
p_g = tl.make_block_ptr(g + (bos + i_t) * H * K, (H, K), (K, 1), (i_h_start, 0), (BH, BK), (0, 1))
113110
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
114111

115112
# Load k[i_t, h:h+BH, :] and beta[i_t, h:h+BH]
116-
p_k = tl.make_block_ptr(k + (bos + i_t) * H * K, (H, K), (K, 1),
117-
(i_h_start, 0), (BH, BK), (0, 1))
113+
p_k = tl.make_block_ptr(k + (bos + i_t) * H * K, (H, K), (K, 1), (i_h_start, 0), (BH, BK), (0, 1))
118114
b_k_self = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
119115

120116
p_beta = beta + (bos + i_t) * H + i_h_start + o_h
@@ -124,28 +120,25 @@ def chunk_kda_fwd_kernel_intra_token_parallel(
124120
for j in range(subchunk_start, tl.minimum(i_t + 1, subchunk_end)):
125121

126122
# Load k[j, h:h+BH, :] with pointer arithmetic
127-
p_k_j = tl.make_block_ptr(k + (bos + j) * H * K, (H, K), (K, 1),
128-
(i_h_start, 0), (BH, BK), (0, 1))
129-
b_k_j = tl.load(p_k_j, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
123+
p_kj = tl.make_block_ptr(k + (bos + j) * H * K, (H, K), (K, 1), (i_h_start, 0), (BH, BK), (0, 1))
124+
b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
130125

131126
# Load g[j, h:h+BH, :]
132-
p_g_j = tl.make_block_ptr(g + (bos + j) * H * K, (H, K), (K, 1),
133-
(i_h_start, 0), (BH, BK), (0, 1))
134-
b_g_j = tl.load(p_g_j, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
127+
p_gj = tl.make_block_ptr(g + (bos + j) * H * K, (H, K), (K, 1), (i_h_start, 0), (BH, BK), (0, 1))
128+
b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) # [BH, BK]
135129

136130
# Compute gated key for all BH heads: [BH, BK]
137131
if USE_EXP2:
138-
b_k_j_gated = b_k_j * exp2(b_g - b_g_j)
132+
b_kgj = b_kj * exp2(b_g - b_gj)
139133
else:
140-
b_k_j_gated = b_k_j * exp(b_g - b_g_j)
134+
b_kgj = b_kj * exp(b_g - b_gj)
141135

142136
# Apply mask for valid K dimension
143-
b_k_j_gated = tl.where(m_k[None, :], b_k_j_gated, 0.0)
137+
b_kgj = tl.where(m_k[None, :], b_kgj, 0.0)
144138

145-
# Compute Aqk and Akk for all BH heads: [BH]
146-
b_Aqk = tl.sum(b_q * b_k_j_gated, axis=1) * scale # [BH]
139+
b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale # [BH]
147140
# Akk: only accumulate if j < i_t
148-
b_Akk = tl.sum(b_k_self * b_k_j_gated, axis=1) * tl.where(j < i_t, 1.0, 0.0) # [BH]
141+
b_Akk = tl.sum(b_k_self * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) # [BH]
149142

150143
# Store with [B, T, H, BT] layout (no transpose needed later)
151144
j_pos = j % BT
@@ -165,6 +158,7 @@ def chunk_kda_fwd_intra_token_parallel(
165158
scale: float,
166159
cu_seqlens: torch.LongTensor | None = None,
167160
chunk_size: int = 64,
161+
sub_chunk_size: int = 16,
168162
use_exp2: bool = False,
169163
) -> None:
170164
"""
@@ -187,6 +181,7 @@ def chunk_kda_fwd_intra_token_parallel(
187181
"""
188182
B, T, H, K = q.shape
189183
BT = chunk_size
184+
BC = sub_chunk_size
190185

191186
# Grid: (total_tokens, H/BH) - each token gets its own block
192187
if cu_seqlens is not None:
@@ -215,5 +210,6 @@ def grid(meta):
215210
H=H,
216211
K=K,
217212
BT=BT,
213+
BC=BC,
218214
USE_EXP2=use_exp2,
219215
)

0 commit comments

Comments
 (0)