diff --git a/fla/ops/deltaformer/parallel.py b/fla/ops/deltaformer/parallel.py index 25a4d0a33..6b1b882be 100644 --- a/fla/ops/deltaformer/parallel.py +++ b/fla/ops/deltaformer/parallel.py @@ -167,7 +167,7 @@ def parallel_deltaformer_fwd_kernel( ) q = tl.load(q_blk_ptr, boundary_check=(0,)) - for kv_i in range(0, T, BLOCK_T): + for kv_i in range(0, T-C, BLOCK_T): k_blk_ptr = tl.make_block_ptr( base=k_ptr + pid_h * D, shape=(D, T), @@ -179,10 +179,6 @@ def parallel_deltaformer_fwd_kernel( k = tl.load(k_blk_ptr, boundary_check=(1,)) qk = tl.dot(q, k) * qk_scale - if kv_i >= T - C: - mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1) - qk = tl.where(mask, -1e6, qk) - rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1)) qk -= rowmax_i[:, None] p = tl.math.exp2(qk) @@ -193,17 +189,41 @@ def parallel_deltaformer_fwd_kernel( acc = acc * alpha[:, None] rowmax = rowmax_i - if kv_i < T - C: - u_blk_ptr = tl.make_block_ptr( - base=u_ptr + pid_h * D, - shape=(T, D), - strides=(H * D, 1), - offsets=(kv_i, 0), - block_shape=(BLOCK_T, D), - order=(1, 0), - ) - u = tl.load(u_blk_ptr, boundary_check=(0,)) - acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc) + u_blk_ptr = tl.make_block_ptr( + base=u_ptr + pid_h * D, + shape=(T, D), + strides=(H * D, 1), + offsets=(kv_i, 0), + block_shape=(BLOCK_T, D), + order=(1, 0), + ) + u = tl.load(u_blk_ptr, boundary_check=(0,)) + acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc) + + for kv_i in range(T-C, T, BLOCK_T): + k_blk_ptr = tl.make_block_ptr( + base=k_ptr + pid_h * D, + shape=(D, T), + strides=(1, H * D), + offsets=(0, kv_i), + block_shape=(D, BLOCK_T), + order=(0, 1), + ) + k = tl.load(k_blk_ptr, boundary_check=(1,)) + qk = tl.dot(q, k) * qk_scale + + mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1) + qk = tl.where(mask, -1e6, qk) + + rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1)) + qk -= rowmax_i[:, None] + p = tl.math.exp2(qk) + + rowsum_i = tl.sum(p, axis=1) + alpha = tl.math.exp2(rowmax - rowmax_i) + rowsum = rowsum * alpha + rowsum_i + acc = acc * alpha[:, None] + rowmax = rowmax_i lse = rowmax + tl.math.log2(rowsum) lse_block_ptr = lse_ptr + pid_h + rowid_block * H @@ -218,7 +238,7 @@ def parallel_deltaformer_fwd_kernel( block_shape=(BLOCK_C, D), order=(1, 0), ) - acc = acc / rowsum[:, None] + acc = acc / (rowsum[:, None] + 1e-9) beta_ptr = tl.make_block_ptr( base=beta_ptr + pid_h, @@ -861,7 +881,7 @@ def _forward_impl( betai = beta_full[b, i:i + Ci, :] w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai) - w = w * betai.unsqueeze(-1) + w = w * betai.unsqueeze(-1).to(torch.float32) if need_aux: wpad = torch.zeros(C, H, C, device=ko.device, dtype=ko.dtype) wpad[:Ci, :, :Ci].copy_(w)