[Deltaformer] kernel improvement; if-else optimization; change w to fp32; add 1e-9 to avoid nan #603
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Refactors the forward kernel to split kv processing into two phases (0..T−C without masking, then T−C..T with masking), adjusts boundary handling, lifts the u-path, adds 1e-9 epsilons to normalization, and standardizes dtype casting by applying explicit float32 casts to betai before weight multiplications across forward paths.
Sequence Diagram(s)

    sequenceDiagram
        autonumber
        participant F as Forward Pass
        participant K as Fwd Kernel
        participant L1 as Initial Loop (0..T−C)
        participant L2 as Tail Loop (T−C..T, masked)
        F->>K: Launch parallel_deltaformer_fwd_kernel(...)
        activate K
        K->>L1: Process tokens without masking
        note right of L1: u-path computed unconditionally<br/>rowmax/rowsum/acc updates<br/>normalize with eps (1e-9)
        L1-->>K: Partial accumulators
        K->>L2: Process remaining tokens with fixed mask
        note right of L2: Separate k-load<br/>masked updates<br/>normalize with eps (1e-9)
        L2-->>K: Final accumulators
        deactivate K
        K-->>F: w (with dtype-consistent betai scaling)
        rect rgba(230,245,255,0.6)
            note over F: In Python forward<br/>betai -> float32, broadcast, multiply w<br/>(applies to per-chunk and cu-seqlens)
        end
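To make the control-flow change concrete, here is a minimal NumPy sketch of the two-phase online-softmax pattern. The function name and shapes are illustrative, it assumes T−C and C are multiples of the block size, and it is not the PR's actual code; the real kernel performs the same rescaling with tl.math.exp2 and block pointers.

```python
import numpy as np

def two_phase_online_softmax(q, k, v, C, block):
    """Illustrative mirror of the kernel's loop split (not the PR's code):
    blocks in [0, T-C) never intersect the causal boundary, so only the
    tail range [T-C, T) pays the masking cost."""
    T = k.shape[0]
    M = q.shape[0]                      # M == C query rows of the chunk
    rowmax = np.full(M, -np.inf)
    rowsum = np.zeros(M)
    acc = np.zeros((M, v.shape[1]))

    def online_update(qk, vblk):
        nonlocal rowmax, rowsum, acc
        rowmax_i = np.maximum(rowmax, qk.max(axis=1))
        p = np.exp(qk - rowmax_i[:, None])
        alpha = np.exp(rowmax - rowmax_i)
        rowsum = rowsum * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vblk
        rowmax = rowmax_i

    # Phase 1: bulk blocks, no mask and no branch in the loop body.
    for i in range(0, T - C, block):
        online_update(q @ k[i:i + block].T, v[i:i + block])

    # Phase 2: tail blocks, strict causal mask (same arithmetic as the kernel).
    rows = np.arange(M)[:, None]
    for i in range(T - C, T, block):
        cols = np.arange(block)[None, :]
        qk = q @ k[i:i + block].T
        qk = np.where(T - C - i + rows - cols < 1, -1e6, qk)
        online_update(qk, v[i:i + block])

    # Epsilon guard mirrors the PR's 1e-9 addition.
    return acc / (rowsum[:, None] + 1e-9)
```

Calling it with, e.g., q = np.random.randn(4, 8), k = v = np.random.randn(16, 8), C=4, block=4 should match a direct strictly causal softmax over the keys, up to the -1e6 masking and 1e-9 epsilon approximations.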
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes

Hello @foreverpiano, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces several minor but impactful improvements to the Deltaformer implementation. The primary goal is to enhance performance through an if-else loop-splitting optimization in the forward kernel, alongside numerical-stability and precision fixes.
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
@yzhangcs Any plans for adding a TMA torch_desc version? I have one ready. @Nathancgy
Code Review
This pull request introduces several improvements to the Deltaformer implementation. It optimizes the forward kernel by splitting a loop to remove an if-else branch, which should improve performance. It also enhances numerical stability by adding a small epsilon to a denominator to prevent division by zero, and by casting a tensor to float32 before multiplication to maintain precision. My review focuses on two points: a significant amount of code duplication introduced by the loop optimization, which could affect maintainability, and an inconsistency where a precision-improving cast was not applied to a similar code path elsewhere in the file. Addressing these points would further improve the code quality.
    w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
-   w = w * betai.unsqueeze(-1)
+   w = w * betai.unsqueeze(-1).to(torch.float32)
Casting betai to torch.float32 here is a good change for maintaining precision during the multiplication. However, there is a similar calculation in the cu_seqlens branch of this function that was not updated. At line 944, the code is still w = w * betai.unsqueeze(-1). For consistency, and to prevent potential numerical issues in that execution path, you should probably apply the same .to(torch.float32) cast there as well.
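For context on why the explicit cast matters: under PyTorch's type-promotion rules, the product is only computed and stored in low precision when both operands are low precision, so the cast guarantees an fp32 multiply regardless of w's dtype. A standalone sketch with hypothetical tensors, not the PR's code:

```python
import torch

w16 = torch.randn(4, 16, dtype=torch.bfloat16)
beta = torch.rand(4, dtype=torch.bfloat16)

# Both operands bf16: the product is computed and stored in bf16.
print((w16 * beta.unsqueeze(-1)).dtype)                    # torch.bfloat16

# Casting one operand to fp32 promotes the whole multiply to fp32.
print((w16 * beta.unsqueeze(-1).to(torch.float32)).dtype)  # torch.float32
```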
    for kv_i in range(T-C, T, BLOCK_T):
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_h * D,
            shape=(D, T),
            strides=(1, H * D),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        k = tl.load(k_blk_ptr, boundary_check=(1,))
        qk = tl.dot(q, k) * qk_scale

        mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
        qk = tl.where(mask, -1e6, qk)

        rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
        qk -= rowmax_i[:, None]
        p = tl.math.exp2(qk)

        rowsum_i = tl.sum(p, axis=1)
        alpha = tl.math.exp2(rowmax - rowmax_i)
        rowsum = rowsum * alpha + rowsum_i
        acc = acc * alpha[:, None]
        rowmax = rowmax_i
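For readers following the bookkeeping above: the loop maintains the invariant that acc and rowsum hold the softmax numerator and denominator stabilized by the running row maximum. When a block raises that maximum from m to m', every previously accumulated term 2^(qk − m) must become 2^(qk − m'), which is a uniform rescale by alpha = 2^(m − m'); this is exactly what alpha = tl.math.exp2(rowmax - rowmax_i) applies to both rowsum and acc before the new block's contributions are added, so the final acc / rowsum equals the exact softmax-weighted sum.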
This new loop, part of the loop-splitting optimization, introduces a significant amount of code duplication from the preceding loop (lines 170-201). Specifically, the logic for creating k_blk_ptr, loading k, computing qk, and the online softmax update (rowmax_i through rowmax = rowmax_i) is almost identical. While splitting loops to avoid branching is a valid performance strategy in Triton, this large-scale duplication can make the code harder to read, maintain, and debug. It would be beneficial to explore whether the duplication can be reduced without sacrificing performance, perhaps by refactoring the common parts into a helper if Triton's JIT compiler allows.
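One possible direction, sketched below with a hypothetical helper name: Triton inlines @triton.jit functions called from other jitted code, so the shared online-softmax step could be factored out, leaving only the mask application different between the two loops. This is a sketch under that assumption, not a tested refactor of the kernel.

```python
import triton
import triton.language as tl

@triton.jit
def _online_softmax_step(qk, rowmax, rowsum, acc):
    # Shared inner step of both loops; Triton inlines jitted helpers at
    # compile time, so factoring it out should not by itself change the
    # generated code.
    rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
    p = tl.math.exp2(qk - rowmax_i[:, None])
    alpha = tl.math.exp2(rowmax - rowmax_i)
    rowsum = rowsum * alpha + tl.sum(p, axis=1)
    acc = acc * alpha[:, None]
    return p, rowmax_i, rowsum, acc
```

Each loop body would then reduce to the k-load, the qk computation, an optional tl.where for the tail, and a call p, rowmax, rowsum, acc = _online_softmax_step(qk, rowmax, rowsum, acc).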
Actionable comments posted: 3
📜 Review details
📒 Files selected for processing (1): fla/ops/deltaformer/parallel.py (4 hunks)
🧰 Additional context used
🪛 GitHub Actions: lint
fla/ops/deltaformer/parallel.py
[error] 1-1: Trailing whitespace detected and removed by pre-commit hook 'trailing-whitespace'.
⏰ Context from checks skipped due to timeout of 90000ms (2):
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
fla/ops/deltaformer/parallel.py (2)
Lines 170-202: LGTM! Loop splitting optimization improves performance.

The refactor splits kv processing into two phases: an unmasked loop for 0..T-C (bulk processing) and a masked loop for T-C..T (boundary handling). Moving the u-path computation into the first loop is correct, as it processes the initial range efficiently without masking overhead.
Lines 203-226: Add unit tests for causal mask boundary logic.

In fla/ops/deltaformer/parallel.py (L203-226), the mask condition (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1) guards off-by-one behavior in the final block. No existing tests cover this; please add targeted unit tests (using small T, C, BLOCK_T values) to validate both masked and unmasked cases.
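One possible shape for such a test, sketched here with hypothetical names and parameter values, assuming the strict-causal semantics implied by the < 1 comparison (a query must not attend to itself):

```python
import torch
import pytest

@pytest.mark.parametrize("T,C,BLOCK_T", [(8, 4, 2), (8, 8, 4), (6, 2, 2), (7, 3, 2)])
def test_tail_mask_boundary(T, C, BLOCK_T):
    # Replay the kernel's mask arithmetic for the tail loop and check the
    # off-by-one behavior at the causal boundary.
    rowid = torch.arange(C)[:, None]           # query rows within the chunk
    for kv_i in range(T - C, T, BLOCK_T):
        colid = torch.arange(BLOCK_T)[None, :]
        mask = (T - C - kv_i + rowid - colid) < 1
        for r in range(C):
            q_abs = T - C + r                  # absolute query position
            for c in range(BLOCK_T):
                k_abs = kv_i + c               # absolute key position
                if k_abs == q_abs:
                    assert mask[r, c]          # self is masked (strict)
                elif k_abs == q_abs - 1:
                    assert not mask[r, c]      # immediate predecessor visible
```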
        )
        u = tl.load(u_blk_ptr, boundary_check=(0,))
        acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
Remove trailing whitespace.
The pre-commit hook detected trailing whitespace on this line, which should be removed to pass the linting check.
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 202, there is trailing whitespace
at the end of the line; remove the extra space(s) or tab characters at the end
of that line (or run the project's pre-commit/formatter) so the file has no
trailing whitespace and the linter/pre-commit check will pass.
        order=(1, 0),
    )
-   acc = acc / rowsum[:, None]
+   acc = acc / (rowsum[:, None] + 1e-9)
🧩 Analysis chain

Good addition of epsilon for numerical stability. Adding 1e-9 to the denominator prevents potential division by zero or NaN when rowsum is very small. This is a standard numerical stability practice.

Verify that similar normalization divisions elsewhere in the codebase (e.g., in backward passes or other kernels) also include epsilon guards for consistency.

🏁 Script executed:

    #!/bin/bash
    # Description: Search for other normalization divisions that might need epsilon
    rg -n "/ *(rowsum|sum|norm)" --type=py -C 2 -g "*.py"

Length of output: 408
Add epsilon to the p normalization for consistency.

The division at line 280, p = tl.math.exp2(qk) / rowsum[:, None], lacks the +1e-9 guard and can still produce infinities or NaNs when rowsum is zero. Change it to:

    - p = tl.math.exp2(qk) / rowsum[:, None]
    + p = tl.math.exp2(qk) / (rowsum[:, None] + 1e-9)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 241, the computation that forms p
divides by rowsum[:, None] without the small epsilon guard; update the division
to add the same +1e-9 epsilon used elsewhere (i.e., divide by rowsum[:, None] +
1e-9) to prevent infinities/NaNs when rowsum is zero and keep numeric stability
consistent across the file.
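As a standalone illustration of the failure mode the epsilon guards against (toy values, not the kernel's tensors):

```python
import torch

acc = torch.ones(2, 4)
rowsum = torch.tensor([3.7, 0.0])   # a fully-masked/underflowed row sums to 0

print(acc / rowsum[:, None])           # second row: inf (0/0 would give nan)
print(acc / (rowsum[:, None] + 1e-9))  # second row: ~1e9, large but finite
```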
    w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
-   w = w * betai.unsqueeze(-1)
+   w = w * betai.unsqueeze(-1).to(torch.float32)
Ensure consistent dtype casting across all code paths.

The explicit .to(torch.float32) cast on betai at line 884 ensures consistent dtype handling when multiplying with w. However, line 944 performs the same operation without the float32 cast:

    w = w * betai.unsqueeze(-1)  # line 944, missing .to(torch.float32)

This inconsistency could lead to different numeric behavior or dtype mismatches between the standard and cu_seqlens code paths. Apply this diff to maintain consistency:

    - w = w * betai.unsqueeze(-1)
    + w = w * betai.unsqueeze(-1).to(torch.float32)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around lines 884 and 944, the multiplication
uses betai.unsqueeze(-1).to(torch.float32) at line 884 but at line 944 the
.to(torch.float32) cast is missing; update the line 944 multiplication to cast
betai to torch.float32 (e.g., betai.unsqueeze(-1).to(torch.float32)) so both
code paths use the same dtype and avoid mismatches.
@foreverpiano thanks for the PR (the two separate loops look good)! As for adding a TMA torch_desc version, that is up to zhangyu to decide. But in the meantime, if you need to test the model while making further improvements, I recommend using the version from this PR, where I fixed a few accuracy issues. That version passes the ops test and should be correct.