
Conversation


@foreverpiano commented Sep 30, 2025

if-else optimization; cast w to fp32; add 1e-9 to avoid NaN

Summary by CodeRabbit

  • Bug Fixes
    • Improved numerical stability by adding a small epsilon to normalization.
    • Ensured consistent dtype handling with explicit float32 casting across forward paths.
    • Corrected masking and boundary handling for late-token processing to prevent incorrect updates.
  • Refactor
    • Restructured forward-pass processing into two phases for clearer control flow and consistent masking.
    • Simplified data flow by lifting computations out of inner loops for readability and maintainability.


coderabbitai bot commented Sep 30, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Refactors the forward kernel to split kv processing into two phases (0..T−C without masking, then T−C..T with masking), adjusts boundary handling, lifts the u‑path, adds 1e-9 epsilons to normalization, and standardizes dtype casting by applying explicit float32 casts to betai before weight multiplications across forward paths.

Changes

Cohort / File(s) Summary
Kernel forward-pass refactor
fla/ops/deltaformer/parallel.py
Splits parallel_deltaformer_fwd_kernel into two loops: initial unmasked range and tail masked range; moves u‑path out of the first loop; separates k-load in tail; updates rowmax/rowsum/acc accordingly; adds 1e-9 epsilon to normalization divisions for stability.
Forward dtype consistency
fla/ops/deltaformer/parallel.py
Ensures betai is cast to float32 before broadcasting and multiplying with w in both per-chunk and cu-seqlens branches; propagates casted weight multiplications consistently.
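The loop split summarized above follows the standard online-softmax rescaling pattern: each block updates a running row maximum and rescales the partial row sum before adding the new block's contribution. The NumPy sketch below re-derives that pattern with illustrative names (`rowmax`/`rowsum` mirror the kernel's running statistics; the real kernel additionally applies the causal mask only in the second phase):

```python
import numpy as np

def two_phase_rowsum(qk, split):
    # Accumulate softmax row statistics block-wise: phase 1 covers
    # columns [0, split) (unmasked in the kernel), phase 2 the tail.
    rowmax = np.full(qk.shape[0], -np.inf)
    rowsum = np.zeros(qk.shape[0])
    for lo, hi in ((0, split), (split, qk.shape[1])):
        blk = qk[:, lo:hi]
        new_max = np.maximum(rowmax, blk.max(axis=1))
        alpha = np.exp(rowmax - new_max)  # rescale the previous partial sum
        rowsum = rowsum * alpha + np.exp(blk - new_max[:, None]).sum(axis=1)
        rowmax = new_max
    return rowmax, rowsum

qk = np.random.default_rng(0).normal(size=(4, 16))
_, rs = two_phase_rowsum(qk, split=12)
ref = np.exp(qk - qk.max(axis=1, keepdims=True)).sum(axis=1)
assert np.allclose(rs, ref)  # splitting the loop does not change the result
```

Because the rescaling is associative over blocks, splitting one loop into an unmasked bulk phase and a masked tail phase is numerically equivalent to the original single loop, while letting the compiler drop the per-iteration branch.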

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant F as Forward Pass
  participant K as Fwd Kernel
  participant L1 as Initial Loop (0..T−C)
  participant L2 as Tail Loop (T−C..T, masked)

  F->>K: Launch parallel_deltaformer_fwd_kernel(...)
  activate K

  K->>L1: Process tokens without masking
  note right of L1: u-path computed unconditionally<br/>rowmax/rowsum/acc updates<br/>normalize with eps (1e-9)
  L1-->>K: Partial accumulators

  K->>L2: Process remaining tokens with fixed mask
  note right of L2: Separate k-load<br/>masked updates<br/>normalize with eps (1e-9)
  L2-->>K: Final accumulators

  deactivate K
  K-->>F: w (with dtype-consistent betai scaling)

  rect rgba(230,245,255,0.6)
  note over F: In Python forward<br/>betai -> float32, broadcast, multiply w<br/>(applies to per-chunk and cu-seqlens)
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibble through loops, two hops, not one,
Mask at the tail, then presto—done.
A pinch of eps for stable brew,
Casts to float32, crisp and true.
With rows that sum and max align,
My kernels hum—oh how they shine! 🐇⚙️

Pre-merge checks

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title Check ⚠️ Warning The current title is overly long and lists multiple distinct changes rather than focusing on the primary improvement, making it harder to quickly discern the main purpose of the pull request. It reads like a changelog entry instead of a concise summary sentence that highlights the key update. Please shorten the title to a single concise sentence that emphasizes the main enhancement, for example: “Improve Deltaformer kernel performance and numerical stability.”
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.


Summary of Changes

Hello @foreverpiano, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces several minor but impactful improvements to the Deltaformer implementation. The primary goal is to enhance performance through an if-else optimization in the kernel and improve numerical stability by explicitly managing floating-point precision and preventing division by zero, ultimately leading to more robust and efficient computations.

Highlights

  • If-Else Optimization: The parallel_deltaformer_fwd_kernel has been refactored to optimize conditional logic by splitting the processing loop into two distinct parts, eliminating an if-else branch within the main loop.
  • Floating Point Precision: The w tensor is now explicitly cast to torch.float32 before multiplication with betai in the _forward_impl function, ensuring consistent floating-point precision.
  • Numerical Stability: A small epsilon value (1e-9) has been added to the denominator during the acc calculation in parallel_deltaformer_fwd_kernel to prevent potential division by zero errors and NaN values.
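The floating-point precision highlight relies on ordinary type-promotion rules: multiplying a half-precision array by a float32 one yields a float32 result, so the product is kept at full precision rather than being rounded element-wise back to fp16. A minimal NumPy sketch of the same promotion rule (hypothetical stand-in values, not the kernel's actual tensors):

```python
import numpy as np

# Stand-ins for the kernel output w and the per-token scales betai
# (illustrative values; the real code multiplies torch tensors).
w = np.full((4, 1), 3.0, dtype=np.float16)
betai = np.full(4, 0.1, dtype=np.float16)

p16 = w * betai[:, None]                                # product rounded to float16
p32 = w.astype(np.float32) * betai[:, None].astype(np.float32)

assert p16.dtype == np.float16 and p32.dtype == np.float32
# fp16 rounding of 3 * fp16(0.1) loses low bits that the fp32 product keeps:
assert float(p16[0, 0]) != float(p32[0, 0])
```

The same promotion applies in PyTorch, which is why casting only `betai` to `torch.float32` is enough to lift the whole multiplication to fp32.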


foreverpiano commented Sep 30, 2025

@yzhangcs Any plan for adding a tma torch_desc version? I've got one ready. @Nathancgy


gemini-code-assist bot left a comment


Code Review

This pull request introduces several improvements to the Deltaformer implementation. It optimizes the forward kernel by splitting a loop to remove an if-else branch, which should improve performance. It also enhances numerical stability by adding a small epsilon to a denominator to prevent division by zero, and by casting a tensor to float32 before multiplication to maintain precision. My review focuses on two points: a significant amount of code duplication introduced by the loop optimization which could affect maintainability, and an inconsistency where a precision-improving cast was not applied to a similar code path elsewhere in the file. Addressing these points would further improve the code quality.


  w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
- w = w * betai.unsqueeze(-1)
+ w = w * betai.unsqueeze(-1).to(torch.float32)


Severity: high

Casting betai to torch.float32 here is a good change for maintaining precision during the multiplication. However, there is a similar calculation in the cu_seqlens branch of this function that was not updated. At line 944, the code is still w = w * betai.unsqueeze(-1). For consistency and to prevent potential numerical issues in that execution path, you should probably apply the same .to(torch.float32) cast there as well.

Comment on lines +203 to +226
for kv_i in range(T-C, T, BLOCK_T):
    k_blk_ptr = tl.make_block_ptr(
        base=k_ptr + pid_h * D,
        shape=(D, T),
        strides=(1, H * D),
        offsets=(0, kv_i),
        block_shape=(D, BLOCK_T),
        order=(0, 1),
    )
    k = tl.load(k_blk_ptr, boundary_check=(1,))
    qk = tl.dot(q, k) * qk_scale

    mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
    qk = tl.where(mask, -1e6, qk)

    rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
    qk -= rowmax_i[:, None]
    p = tl.math.exp2(qk)

    rowsum_i = tl.sum(p, axis=1)
    alpha = tl.math.exp2(rowmax - rowmax_i)
    rowsum = rowsum * alpha + rowsum_i
    acc = acc * alpha[:, None]
    rowmax = rowmax_i


Severity: medium

This new loop, part of the loop-splitting optimization, introduces a significant amount of code duplication from the preceding loop (lines 170-201). Specifically, the logic for creating k_blk_ptr, loading k, computing qk, and the online softmax update (rowmax_i to rowmax = rowmax_i) is almost identical. While splitting loops to avoid branching is a valid performance strategy in Triton, this large-scale duplication can make the code harder to read, maintain, and debug. It would be beneficial to explore if there's a way to reduce this duplication without sacrificing performance, perhaps by refactoring common parts if Triton's JIT compiler allows.


coderabbitai bot left a comment


Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c3bd565 and 01f7f9f.

📒 Files selected for processing (1)
  • fla/ops/deltaformer/parallel.py (4 hunks)
🧰 Additional context used
🪛 GitHub Actions: lint
fla/ops/deltaformer/parallel.py

[error] 1-1: Trailing whitespace detected and removed by pre-commit hook 'trailing-whitespace'.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
fla/ops/deltaformer/parallel.py (2)

170-202: LGTM! Loop splitting optimization improves performance.

The refactor splits kv processing into two phases: an unmasked loop for 0..T-C (bulk processing) and a masked loop for T-C..T (boundary handling). Moving the u-path computation into the first loop is correct, as it processes the initial range efficiently without masking overhead.


203-226: Add unit tests for causal mask boundary logic
In fla/ops/deltaformer/parallel.py (L203–226), the mask condition

(T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)

guards off-by-one behavior in the final block. No existing tests cover this—please add targeted unit tests (using small T, C, BLOCK_T values) to validate both masked and unmasked cases.
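One way such a test could look, assuming `rowid_block`/`colid_block` are intra-block offsets, the queries occupy the last C positions, and the mask's intent is to exclude keys at or after each query position (these conventions are inferred from the snippet, not verified against the kernel):

```python
import numpy as np

def tail_mask(T, C, kv_i, n_rows, n_cols):
    # Mirror of the kernel's tail-loop mask expression for one tile.
    rowid = np.arange(n_rows)[:, None]
    colid = np.arange(n_cols)[None, :]
    return (T - C - kv_i + rowid - colid) < 1

# Small, exhaustive check: compare against absolute token positions.
T, C, BLOCK_T = 8, 4, 2
for kv_i in range(T - C, T, BLOCK_T):
    m = tail_mask(T, C, kv_i, C, BLOCK_T)
    abs_row = (T - C) + np.arange(C)[:, None]      # query positions
    abs_col = kv_i + np.arange(BLOCK_T)[None, :]   # key positions in this tile
    assert (m == (abs_col >= abs_row)).all()       # masks keys at/after the query
```

Rewriting the expression as `abs_row - abs_col < 1` shows the mask excludes the diagonal itself (strictly lower-triangular), which is the off-by-one behavior a test like this would pin down.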

)
u = tl.load(u_blk_ptr, boundary_check=(0,))
acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)



⚠️ Potential issue | 🟡 Minor

Remove trailing whitespace.

The pre-commit hook detected trailing whitespace on this line, which should be removed to pass the linting check.

🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 202, there is trailing whitespace
at the end of the line; remove the extra space(s) or tab characters at the end
of that line (or run the project's pre-commit/formatter) so the file has no
trailing whitespace and the linter/pre-commit check will pass.

      order=(1, 0),
  )
- acc = acc / rowsum[:, None]
+ acc = acc / (rowsum[:, None] + 1e-9)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Good addition of epsilon for numerical stability.

Adding 1e-9 to the denominator prevents potential division by zero or NaN when rowsum is very small. This is a standard numerical stability practice.

Verify that similar normalization divisions elsewhere in the codebase (e.g., in backward passes or other kernels) also include epsilon guards for consistency:


🏁 Script executed:

#!/bin/bash
# Description: Search for other normalization divisions that might need epsilon

rg -n "/ *(rowsum|sum|norm)" --type=py -C 2 -g "*.py"

Length of output: 408


Add epsilon to the p normalization for consistency.

The division at line 280

p = tl.math.exp2(qk) / rowsum[:, None]

lacks the +1e-9 guard and can still produce infinities or NaNs when rowsum is zero. Change it to:

- p = tl.math.exp2(qk) / rowsum[:, None]
+ p = tl.math.exp2(qk) / (rowsum[:, None] + 1e-9)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around line 241, the computation that forms p
divides by rowsum[:, None] without the small epsilon guard; update the division
to add the same +1e-9 epsilon used elsewhere (i.e., divide by rowsum[:, None] +
1e-9) to prevent infinities/NaNs when rowsum is zero and keep numeric stability
consistent across the file.
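The failure mode the epsilon guards against is easy to reproduce: when every exp2 term in a row underflows to zero (for example, heavily masked scores), the bare division produces NaN, while the guarded one stays finite. A NumPy sketch of the suggested fix (illustrative values only):

```python
import numpy as np

qk = np.full((1, 4), -1e6, dtype=np.float32)  # e.g. a fully suppressed row
p_raw = np.exp2(qk)                           # underflows to exactly 0.0
rowsum = p_raw.sum(axis=1)
assert rowsum[0] == 0.0

with np.errstate(invalid="ignore"):
    bad = p_raw / rowsum[:, None]             # 0/0 -> NaN
good = p_raw / (rowsum[:, None] + 1e-9)       # finite (all zeros)

assert np.isnan(bad).all()
assert np.isfinite(good).all()
```

A row of zeros is the sensible result for a fully suppressed row, whereas NaN silently poisons every downstream accumulation, which is why the epsilon matters even in "can't happen" branches.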


  w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
- w = w * betai.unsqueeze(-1)
+ w = w * betai.unsqueeze(-1).to(torch.float32)


⚠️ Potential issue | 🟡 Minor

Ensure consistent dtype casting across all code paths.

The explicit .to(torch.float32) cast on betai at line 884 ensures consistent dtype handling when multiplying with w. However, line 944 performs the same operation without the float32 cast:

w = w * betai.unsqueeze(-1)  # line 944, missing .to(torch.float32)

This inconsistency could lead to different numeric behavior or dtype mismatches between the standard and cu_seqlens code paths.

Apply this diff to maintain consistency:

-                w = w * betai.unsqueeze(-1)
+                w = w * betai.unsqueeze(-1).to(torch.float32)
🤖 Prompt for AI Agents
In fla/ops/deltaformer/parallel.py around lines 884 and 944, the multiplication
uses betai.unsqueeze(-1).to(torch.float32) at line 884 but at line 944 the
.to(torch.float32) cast is missing; update the line 944 multiplication to cast
betai to torch.float32 (e.g., betai.unsqueeze(-1).to(torch.float32)) so both
code paths use the same dtype and avoid mismatches.

@foreverpiano foreverpiano changed the title [Deltaformer] Minor improvement; if-else optimization; change w to fp32; add 1e-9 to avoid nan [Deltaformer] kernel improvement; if-else optimization; change w to fp32; add 1e-9 to avoid nan Sep 30, 2025
@Nathancgy

@foreverpiano thanks for the PR (the two separate loops look good)! As for adding a tma torch_desc version, that's up to zhangyu to decide. In the meantime, if you need to test the model while making further improvements, I recommend using the version from this PR, where I fixed a few accuracy issues. That version passes the ops tests and should be correct.
