56 changes: 38 additions & 18 deletions fla/ops/deltaformer/parallel.py
@@ -167,7 +167,7 @@ def parallel_deltaformer_fwd_kernel(
     )
     q = tl.load(q_blk_ptr, boundary_check=(0,))

-    for kv_i in range(0, T, BLOCK_T):
+    for kv_i in range(0, T-C, BLOCK_T):
         k_blk_ptr = tl.make_block_ptr(
             base=k_ptr + pid_h * D,
             shape=(D, T),
@@ -179,10 +179,6 @@
         k = tl.load(k_blk_ptr, boundary_check=(1,))
         qk = tl.dot(q, k) * qk_scale

-        if kv_i >= T - C:
-            mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
-            qk = tl.where(mask, -1e6, qk)
-
         rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
         qk -= rowmax_i[:, None]
         p = tl.math.exp2(qk)
@@ -193,17 +189,41 @@
         acc = acc * alpha[:, None]
         rowmax = rowmax_i

-        if kv_i < T - C:
-            u_blk_ptr = tl.make_block_ptr(
-                base=u_ptr + pid_h * D,
-                shape=(T, D),
-                strides=(H * D, 1),
-                offsets=(kv_i, 0),
-                block_shape=(BLOCK_T, D),
-                order=(1, 0),
-            )
-            u = tl.load(u_blk_ptr, boundary_check=(0,))
-            acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
+        u_blk_ptr = tl.make_block_ptr(
+            base=u_ptr + pid_h * D,
+            shape=(T, D),
+            strides=(H * D, 1),
+            offsets=(kv_i, 0),
+            block_shape=(BLOCK_T, D),
+            order=(1, 0),
+        )
+        u = tl.load(u_blk_ptr, boundary_check=(0,))
+        acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
+

⚠️ Potential issue | 🟡 Minor

Remove trailing whitespace.

The pre-commit hook detected trailing whitespace on this line, which should be removed to pass the linting check.

🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 202, there is trailing whitespace
at the end of the line; remove the extra space(s) or tab characters at the end
of that line (or run the project's pre-commit/formatter) so the file has no
trailing whitespace and the linter/pre-commit check will pass.

+    for kv_i in range(T-C, T, BLOCK_T):
+        k_blk_ptr = tl.make_block_ptr(
+            base=k_ptr + pid_h * D,
+            shape=(D, T),
+            strides=(1, H * D),
+            offsets=(0, kv_i),
+            block_shape=(D, BLOCK_T),
+            order=(0, 1),
+        )
+        k = tl.load(k_blk_ptr, boundary_check=(1,))
+        qk = tl.dot(q, k) * qk_scale
+
+        mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
+        qk = tl.where(mask, -1e6, qk)
+
+        rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
+        qk -= rowmax_i[:, None]
+        p = tl.math.exp2(qk)
+
+        rowsum_i = tl.sum(p, axis=1)
+        alpha = tl.math.exp2(rowmax - rowmax_i)
+        rowsum = rowsum * alpha + rowsum_i
+        acc = acc * alpha[:, None]
+        rowmax = rowmax_i
Comment on lines +203 to +226


medium

This new loop, part of the loop-splitting optimization, duplicates a significant amount of code from the preceding loop (lines 170-201): the logic for creating k_blk_ptr, loading k, computing qk, and the online-softmax update (rowmax_i through rowmax = rowmax_i) is almost identical. Splitting loops to avoid branching is a valid performance strategy in Triton, but duplication at this scale makes the kernel harder to read, maintain, and debug. It would be worth exploring whether the common parts can be factored out without sacrificing performance, for example into a @triton.jit helper, as sketched below.
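
A minimal sketch of that refactor, assuming Triton's support for @triton.jit helper functions (which are inlined at compile time, so the split-loop structure and its performance should be unaffected); the helper name _online_softmax_step and its signature are hypothetical:

import triton
import triton.language as tl

@triton.jit
def _online_softmax_step(qk, rowmax, rowsum, acc):
    # One streaming-softmax update: when a new block raises the per-row
    # maximum, rescale the running sum and accumulator, then fold the new
    # block's unnormalized probabilities into the running state.
    rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
    p = tl.math.exp2(qk - rowmax_i[:, None])
    alpha = tl.math.exp2(rowmax - rowmax_i)
    rowsum = rowsum * alpha + tl.sum(p, axis=1)
    acc = acc * alpha[:, None]
    return p, rowmax_i, rowsum, acc

Both loops could then compute qk (applying the causal mask only in the second loop) and call p, rowmax, rowsum, acc = _online_softmax_step(qk, rowmax, rowsum, acc), leaving the mask as the only difference between them.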


     lse = rowmax + tl.math.log2(rowsum)
     lse_block_ptr = lse_ptr + pid_h + rowid_block * H
@@ -218,7 +238,7 @@ def parallel_deltaformer_fwd_kernel(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    acc = acc / rowsum[:, None]
+    acc = acc / (rowsum[:, None] + 1e-9)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Good addition of epsilon for numerical stability.

Adding 1e-9 to the denominator prevents potential division by zero or NaN when rowsum is very small. This is a standard numerical stability practice.

Verify that similar normalization divisions elsewhere in the codebase (e.g., in backward passes or other kernels) also include epsilon guards for consistency:


🏁 Script executed:

#!/bin/bash
# Description: Search for other normalization divisions that might need epsilon

rg -n "/ *(rowsum|sum|norm)" --type=py -C 2 -g "*.py"

Length of output: 408


Add epsilon to the p normalization for consistency.

The division at line 280

p = tl.math.exp2(qk) / rowsum[:, None]

lacks the +1e-9 guard and can still produce infinities or NaNs when rowsum is zero. Change it to:

- p = tl.math.exp2(qk) / rowsum[:, None]
+ p = tl.math.exp2(qk) / (rowsum[:, None] + 1e-9)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 241, the computation that forms p
divides by rowsum[:, None] without the small epsilon guard; update the division
to add the same +1e-9 epsilon used elsewhere (i.e., divide by rowsum[:, None] +
1e-9) to prevent infinities/NaNs when rowsum is zero and keep numeric stability
consistent across the file.
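
A minimal PyTorch sketch of the failure mode the guard prevents (values are hypothetical; torch.exp2 stands in for tl.math.exp2):

import torch

qk = torch.tensor([[-1e6, -1e6]])   # a fully masked row; exp2 underflows to 0.0
rowsum = torch.exp2(qk).sum(dim=1)  # 0.0 for this row

p_bad = torch.exp2(qk) / rowsum[:, None]          # 0/0 -> nan
p_ok = torch.exp2(qk) / (rowsum[:, None] + 1e-9)  # 0/eps -> 0.0, stays finite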


     beta_ptr = tl.make_block_ptr(
         base=beta_ptr + pid_h,
@@ -861,7 +881,7 @@ def _forward_impl(
         betai = beta_full[b, i:i + Ci, :]

         w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
-        w = w * betai.unsqueeze(-1)
+        w = w * betai.unsqueeze(-1).to(torch.float32)


high

Casting betai to torch.float32 here is a good change for maintaining precision during the multiplication. However, there is a similar calculation in the cu_seqlens branch of this function that was not updated. At line 944, the code is still w = w * betai.unsqueeze(-1). For consistency and to prevent potential numerical issues in that execution path, you should probably apply the same .to(torch.float32) cast there as well.


⚠️ Potential issue | 🟡 Minor

Ensure consistent dtype casting across all code paths.

The explicit .to(torch.float32) cast on betai at line 884 ensures consistent dtype handling when multiplying with w. However, line 944 performs the same operation without the float32 cast:

w = w * betai.unsqueeze(-1)  # line 944, missing .to(torch.float32)

This inconsistency could lead to different numeric behavior or dtype mismatches between the standard and cu_seqlens code paths.

Apply this diff to maintain consistency:

-                w = w * betai.unsqueeze(-1)
+                w = w * betai.unsqueeze(-1).to(torch.float32)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around lines 884 and 944, the multiplication
uses betai.unsqueeze(-1).to(torch.float32) at line 884 but at line 944 the
.to(torch.float32) cast is missing; update the line 944 multiplication to cast
betai to torch.float32 (e.g., betai.unsqueeze(-1).to(torch.float32)) so both
code paths use the same dtype and avoid mismatches.
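
A minimal sketch of the promotion rule at stake (shapes and dtypes are hypothetical; bfloat16 stands in for whatever reduced precision w and betai carry):

import torch

w = torch.randn(8, 4, 8, dtype=torch.bfloat16)
betai = torch.rand(8, 4, dtype=torch.bfloat16)

print((w * betai.unsqueeze(-1)).dtype)                    # torch.bfloat16
print((w * betai.unsqueeze(-1).to(torch.float32)).dtype)  # torch.float32

Under PyTorch's type-promotion rules, casting one operand to float32 promotes the whole product, so applying the cast in both branches keeps the two code paths consistent.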

         if need_aux:
             wpad = torch.zeros(C, H, C, device=ko.device, dtype=ko.dtype)
             wpad[:Ci, :, :Ci].copy_(w)