
Conversation


@Pan-Yuqi (Contributor) commented on Mar 15, 2025

Summary by CodeRabbit

  • New Features

    • Enhanced core operations to support an additional bias state, improving handling for variable-length sequences and bias configurations.
    • Expanded benchmarking capabilities with a new method that supplements existing performance evaluations.
  • Tests

    • Updated test cases to validate the new bias functionality and parameter naming, ensuring consistent and reliable computations.


coderabbitai bot commented Mar 15, 2025

Walkthrough

The changes enhance the TTT linear operations by adding a new fused function and incorporating an additional bias parameter for the initial state in multiple modules. The updates span benchmark scripts, core operation files, and tests. Function signatures in forward and backward kernels across the chunk, fused_chunk, and naive implementations now require extra parameters (such as hb0/initial_state_bias) and include revised computation logic. Test cases have been augmented to supply and validate the new bias data, ensuring that the modifications are integrated end‐to‐end.

Changes

  • benchmarks/ops/benchmark_ttt.py: Added new function fused_chunk_ttt_linear to benchmarking; updated parameter naming (using chunk_size=16 rather than BT=16); modified headdim_vals and expanded the methods list to include the fused operation.
  • fla/ops/ttt/chunk.py, fla/ops/ttt/naive.py: Updated function signatures to incorporate a new bias parameter (hb0/initial_state_bias); revised forward and backward kernel logic to include bias computations; modified return values to now include the final state bias.
  • fla/ops/ttt/fused_chunk.py: Enhanced fused kernel functionality with additional parameters (h0, hb0, ht, hbt, offsets, etc.); updated index and conditional logic for handling initial states and variable-length sequences; adjusted both forward and backward kernel signatures and computations accordingly.
  • tests/ops/test_ttt.py: Modified test functions to include a new bias tensor (e.g., hb0); updated parameterization (e.g., new value 16 for T) and adjusted calls to the updated TTT functions; extended assertions to verify the outputs and gradients for the new bias functionality.
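To make the interface change concrete, here is a minimal usage sketch. The import path, tensor shapes, dtypes, and the exact argument order beyond the names listed above (chunk_size, initial_state, initial_state_bias, output_final_state) are assumptions drawn from this summary, not the library's documented API.

import torch
from fla.ops.ttt import chunk_ttt_linear  # assumed export path

B, H, T, D = 2, 4, 64, 64
q, k, v, w, b = (torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16,
                             requires_grad=True) for _ in range(5))
eta = torch.rand(B, H, T, 1, device='cuda', dtype=torch.bfloat16)   # assumed shape
h0 = torch.zeros(B, H, D, D, device='cuda', dtype=torch.float32)    # initial state
hb0 = torch.zeros(B, H, 1, D, device='cuda', dtype=torch.float32)   # new bias state

# The wrapper now also accepts the bias state and returns its final value alongside ht.
o, ht, hbt = chunk_ttt_linear(q, k, v, w, b, eta, chunk_size=16,
                              initial_state=h0, initial_state_bias=hb0,
                              output_final_state=True)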

Sequence Diagram(s)

sequenceDiagram
    participant U as User
    participant F as FusedChunkTTTLinear
    participant FK as FusedChunkTTTLinearFwdKernel
    participant BK as FusedChunkTTTLinearBwdKernel
    U->>+F: Call fused_chunk_ttt_linear(q, k, v, w, b, eta, ...)
    F->>+FK: Execute forward pass with (h0, hb0, ht, hbt, offsets, ...)
    FK-->>-F: Return computed output & final state bias
    alt Backward Pass
        F->>+BK: Execute backward kernels to compute gradients
        BK-->>-F: Return gradients (including dhb0)
    end
    F-->>U: Return final outputs and gradients (if computed)
sequenceDiagram
    participant B as BenchmarkTest
    participant C as ChunkTTTMethod
    participant F as FusedChunkTTTMethod
    B->>+C: Call chunk_ttt_linear with chunk_size=16
    C-->>-B: Return output and timing ("time_b")
    B->>+F: Call fused_chunk_ttt_linear with updated parameters
    F-->>-B: Return output and timing ("time_f_b")
    B-->>B: Compare benchmark results

Poem

I'm a rabbit racing through the code,
Hopping on tweaks in a new mode,
Fused functions and bias refined,
With every parameter, joy defined.
My whiskers twitch with every change,
A bunny's celebration, smooth and strange!
🐇💻 Happy hops in the code range!


coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (17)
benchmarks/ops/benchmark_ttt.py (2)

35-36: Keep consistent dimension options for headdim_vals
Removing 128 from headdim_vals restricts the benchmark to a single head dimension. If you plan to benchmark larger head dimensions in the future, consider re-adding them or moving the list to a config-driven approach.


88-99: Benchmark coverage for fused_chunk_ttt_linear looks solid
This block properly tests forward/backward performance for the new fused function. However, note the pipeline log shows line-length (E501) issues around these lines (e.g., line 90). Consider wrapping or splitting the lines to adhere to style guidelines.

- k = torch.nn.functional.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
+ k_t = torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype)
+ k = torch.nn.functional.normalize(k_t, p=2, dim=-1).requires_grad_(True)
🧰 Tools
🪛 GitHub Actions: pr

[error] 90-90: flake8: E501 line too long (143 > 127 characters)

tests/ops/test_ttt.py (6)

84-86: Multiple sums in backward pass
Combining multiple sums into a single backward call is fine. Watch out for flake8 line-length or spacing issues (see pipeline logs).

🧰 Tools
🪛 GitHub Actions: pr

[error] 84-84: flake8: E501 line too long (141 > 127 characters)


88-99: Reference implementation chunk_ttt_linear_ref updated with bias
Balanced approach to test reference vs. new function. Also note potential line-length warnings (E501).


122-122: Potential duplication?
Check whether lines 121-122 duplicate the parametrization of the preceding test. If a separate test is not needed, consider deduplicating.


125-125: Single scale value usage
Defining [0.1] is fine, but you can parametrize multiple scale values if needed for broader coverage.
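If broader coverage is wanted later, a pytest.mark.parametrize sweep is enough; the values and test name below are illustrative only, not part of the PR.

import pytest

@pytest.mark.parametrize("scale", [0.1, 0.5, 1.0])  # illustrative values
def test_scale_sweep(scale):
    assert scale > 0  # placeholder body; the real chunk_ttt_linear checks go here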


157-161: Mapping everything to .cuda().requires_grad_(True) again
Uniform usage across code. Keep an eye on flake8 line length if these lines get longer.


162-210: Testing fused_chunk_ttt_linear with new bias parameter
The block thoroughly validates forward/backward for fused approach. Pipeline logs mention line-length warnings in this region; consider splitting lines to comply with style.

🧰 Tools
🪛 GitHub Actions: pr

[error] 174-174: flake8: E222 multiple spaces after operator


[error] 175-175: flake8: E501 line too long (141 > 127 characters)


[error] 191-191: flake8: E222 multiple spaces after operator


[error] 192-192: flake8: E222 multiple spaces after operator


[error] 204-204: flake8: E231 missing whitespace after ','


[error] 206-206: flake8: E231 missing whitespace after ','

fla/ops/ttt/chunk.py (4)

97-105: Path for HEAD_FIRST vs. else path
Logic matches how you handle offsets/chunking throughout. Just keep a watch on line length if lines expand further.


239-340: Backward kernel extended to handle hb/hb0
A large chunk of changes. The approach carefully mirrors the forward logic, and the final updates to b_hb are consistent. Watch out for the flake8 continuation-indent warning (E127) around line 332.

🧰 Tools
🪛 GitHub Actions: pr

[error] 332-332: flake8: E127 continuation line over-indented for visual indent


897-923: Detailed backward kernel call for partial results
Storing v_new, x, y, rstd is consistent with the new design. Potential line-length issues remain but logic is fine.


1382-1382: Exceeding line length (E501) in forward function signature
Consider wrapping the arguments to fix flake8 errors:

- def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state, initial_state_bias, output_final_state, offsets, head_first):
+ def forward(
+     ctx, q, k, v, w, b, BT, eta, scale, eps,
+     initial_state, initial_state_bias,
+     output_final_state, offsets, head_first
+ ):
🧰 Tools
🪛 GitHub Actions: pr

[error] 1382-1382: flake8: E501 line too long (133 > 127 characters)

fla/ops/ttt/fused_chunk.py (5)

4-4: Remove the unused import.

The static analysis tools flag typing.Tuple as unused. Removing it will address the flake8 error (F401).

-from typing import Optional, Tuple
+from typing import Optional
🧰 Tools
🪛 Ruff (0.8.2)

4-4: typing.Tuple imported but unused

Remove unused import: typing.Tuple

(F401)

🪛 GitHub Actions: pr

[error] 4-4: flake8: F401 'typing.Tuple' imported but unused


192-192: Remove the unused variable.

eos is assigned but never used. Removing this assignment will resolve the flake8 error (F841).

-    bos, eos = i_n * T, i_n * T + T
+    bos = i_n * T
🧰 Tools
🪛 Ruff (0.8.2)

192-192: Local variable eos is assigned to but never used

Remove assignment to unused variable eos

(F841)

🪛 GitHub Actions: pr

[error] 192-192: flake8: F841 local variable 'eos' is assigned to but never used


335-335: Remove the unused variable again.

Here as well, eos is never referenced after assignment. Removing it prevents confusion and addresses the static analysis warning.

-    bos, eos = i_n * T, i_n * T + T
+    bos = i_n * T
🧰 Tools
🪛 Ruff (0.8.2)

335-335: Local variable eos is assigned to but never used

Remove assignment to unused variable eos

(F841)


757-778: Consider refactoring repeated backward logic.

The backward pass code is quite extensive. If there’s repeated logic between fused_chunk_ttt_linear_bwd_h and fused_chunk_ttt_linear_bwd_dh, consider extracting it into shared helper functions to improve maintainability and reduce duplication.
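As a sketch of this suggestion (helper name, arguments, and defaults are hypothetical, not taken from the PR), the shared launch bookkeeping could live in one place and be called by both backward entry points:

import triton

def _bwd_launch_meta(T: int, K: int, V: int, chunk_size: int, BK: int = 64, BV: int = 64):
    """Hypothetical shared helper: block sizes and grid counts that both
    backward entry points would otherwise compute independently."""
    BT = min(chunk_size, T)
    return BT, triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)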


865-871: Evaluate variable-length support in head-first mode.

Currently, an exception is raised when head_first is True and cu_seqlens is provided. If you plan to support variable-length inputs in the future for head-first mode, consider consolidating or documenting the differences needed to enable it.
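For reference, the guard being discussed typically reduces to a check like the one below; the argument names follow the review, while the function name and message wording are assumptions.

def check_varlen_layout(head_first: bool, cu_seqlens) -> None:
    # Variable-length batches are only handled in the [B, T, H, D] layout today.
    if head_first and cu_seqlens is not None:
        raise RuntimeError(
            "Variable-length inputs (cu_seqlens) are not supported with head_first=True; "
            "pass tensors in the [B, T, H, D] layout instead."
        )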

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 80e5b0c and 91db479.

📒 Files selected for processing (5)
  • benchmarks/ops/benchmark_ttt.py (3 hunks)
  • fla/ops/ttt/chunk.py (45 hunks)
  • fla/ops/ttt/fused_chunk.py (3 hunks)
  • fla/ops/ttt/naive.py (6 hunks)
  • tests/ops/test_ttt.py (13 hunks)
🧰 Additional context used
🪛 GitHub Actions: pr
benchmarks/ops/benchmark_ttt.py

[error] 10-10: flake8: F401 'fla.ops.ttt.naive.chunk_ttt_linear_ref' imported but unused


[error] 24-24: flake8: E302 expected 2 blank lines, found 1


[error] 66-66: flake8: E501 line too long (143 > 127 characters)


[error] 77-77: flake8: E501 line too long (143 > 127 characters)


[error] 90-90: flake8: E501 line too long (143 > 127 characters)

tests/ops/test_ttt.py

[error] 83-83: flake8: E222 multiple spaces after operator


[error] 84-84: flake8: E501 line too long (141 > 127 characters)


[error] 100-100: flake8: E222 multiple spaces after operator


[error] 101-101: flake8: E222 multiple spaces after operator


[error] 101-101: flake8: E501 line too long (142 > 127 characters)


[error] 113-113: flake8: E231 missing whitespace after ','


[error] 115-115: flake8: E231 missing whitespace after ','


[error] 174-174: flake8: E222 multiple spaces after operator


[error] 175-175: flake8: E501 line too long (141 > 127 characters)


[error] 191-191: flake8: E222 multiple spaces after operator


[error] 192-192: flake8: E222 multiple spaces after operator


[error] 204-204: flake8: E231 missing whitespace after ','


[error] 206-206: flake8: E231 missing whitespace after ','

fla/ops/ttt/chunk.py

[error] 121-121: flake8: E127 continuation line over-indented for visual indent


[error] 332-332: flake8: E127 continuation line over-indented for visual indent


[error] 1382-1382: flake8: E501 line too long (133 > 127 characters)

fla/ops/ttt/fused_chunk.py

[error] 4-4: flake8: F401 'typing.Tuple' imported but unused


[error] 192-192: flake8: F841 local variable 'eos' is assigned to but never used

🪛 Ruff (0.8.2)
fla/ops/ttt/fused_chunk.py

4-4: typing.Tuple imported but unused

Remove unused import: typing.Tuple

(F401)


192-192: Local variable eos is assigned to but never used

Remove assignment to unused variable eos

(F841)


335-335: Local variable eos is assigned to but never used

Remove assignment to unused variable eos

(F841)

🔇 Additional comments (58)
benchmarks/ops/benchmark_ttt.py (3)

9-9: Import for fused_chunk_ttt_linear is properly utilized
This addition is coherent with the subsequent usage of fused_chunk_ttt_linear.


41-41: Expanded methods list
Including "fused_chunk_ttt_linear" ensures coverage of the new fused operation in benchmarks. No issues detected.


81-85: Chunk size parameter rename is consistent
Switching from BT to chunk_size improves clarity. Ensure the same naming convention is applied consistently throughout the codebase.
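For context, the time_b / time_f_b comparison in the benchmark diagram above usually boils down to a wrapper like this; do_bench is real Triton API, but how the script arranges its arguments is an assumption, with chunk_size=16 taken from the rename noted here.

from triton.testing import do_bench

def bench_fwd_bwd(fn, *args, **kwargs):
    """Time forward+backward of one TTT linear variant."""
    def run():
        out = fn(*args, **kwargs)
        o = out[0] if isinstance(out, tuple) else out
        o.sum().backward()
    return do_bench(run)

# Hypothetical usage mirroring the benchmark diagram:
# time_b   = bench_fwd_bwd(chunk_ttt_linear, q, k, v, w, b, eta, chunk_size=16)
# time_f_b = bench_fwd_bwd(fused_chunk_ttt_linear, q, k, v, w, b, eta, chunk_size=16)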

tests/ops/test_ttt.py (17)

31-31: Added 16 to the list of T values for test parametrization
It’s good to test smaller sequences. Keep an eye on potential edge cases at T=16.


55-55: New hb0 definition
Storing bias in hb0 ensures distinct coverage for initial state bias. No immediate issues here.


64-64: Repeated definition of hb0
It's consistent with the scenario where head_first=False. The approach mirrors line 55. No concerns.


66-66: Consolidated .requires_grad_(True) calls
The map(lambda x: x.cuda().requires_grad_(True), ...) pattern is consistent across parameters, simplifying code.


69-69: Added dhbt for bias gradient testing
Ensures the new bias term is fully tested in backward passes.


71-82: Valid incorporation of initial_state_bias into chunk_ttt_linear
Properly passing initial_state_bias=hb0.clone() ensures the test covers new bias functionality. Code changes appear correct.
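A condensed sketch of the test pattern, continuing the tensor setup from the sketch after the Changes table; the cotangents are random, and the keyword names remain assumptions rather than the exact test code.

h0.requires_grad_(True)
hb0.requires_grad_(True)
o, ht, hbt = chunk_ttt_linear(q, k, v, w, b, eta, chunk_size=16,
                              initial_state=h0.clone(),
                              initial_state_bias=hb0.clone(),
                              output_final_state=True)
do, dht, dhbt = torch.randn_like(o), torch.randn_like(ht), torch.randn_like(hbt)
((o * do).sum() + (ht * dht).sum() + (hbt * dhbt).sum()).backward()
# The real test compares hb0.grad against the reference implementation's gradient;
# here we only check that the new bias state received a gradient of the right shape.
assert hb0.grad is not None and hb0.grad.shape == hb0.shape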


101-102: Backward pass for reference output
The usage of an additional backward for ref ensures parity with the new function. Looks correct.

🧰 Tools
🪛 GitHub Actions: pr

[error] 101-101: flake8: E222 multiple spaces after operator


[error] 101-101: flake8: E501 line too long (142 > 127 characters)


104-118: Extended output and gradient assertions
Testing ref vs. tri for all components, including hbt, is crucial for verifying correctness of the new bias term. This thoroughness is commendable.

🧰 Tools
🪛 GitHub Actions: pr

[error] 113-113: flake8: E231 missing whitespace after ','


[error] 115-115: flake8: E231 missing whitespace after ','


121-121: Parametrization repeated for T
Including [16, 30, 32, 63, 64, 256] again in another test ensures comprehensive coverage.


146-147: Preparing hb0 for fused_chunk test
Same pattern as the chunk test. The approach is consistent. No issues.


213-216: Additional test parameters for chunk_varlen_fwd
Expanding coverage for wide T range. Looks consistent with other tests.


243-244: Adding hb0 in variable-length scenario
Ensures the new bias dimension is tested for chunked variable lengths. Good coverage.


246-246: Integrating bias in chunk_ttt_linear call
Ensuring the function is tested under real conditions. No immediate issues.


263-281: Reference chunk_ttt_linear_ref logic for varlen
Maintaining the same logic for reference ensures consistent testing. No functional concerns.


283-283: Concatenating ref_hbt
Combining bias states from each chunk is correct for final comparative checks.


286-287: Asserting final states for varlen scenario
Ensures new bias dimension is validated end-to-end. No concerns.


3-3:

✅ Verification successful

Ensure os import is necessary
You introduced import os. Confirm you are actually using it in this file to avoid unused imports.


🏁 Script executed:

#!/bin/bash
rg --context 3 "os\." tests/ops/test_ttt.py

Length of output: 265


Verified: os Import is Necessary

The os module is used within the file—specifically in the line setting os.environ['TRITON_F32_DEFAULT'] = 'ieee'—so the import should be retained.

  • Verified usage in tests/ops/test_ttt.py confirming that os.environ is used.
fla/ops/ttt/naive.py (8)

36-36: Conditional initialization of hb
Using the provided bias if available is valid. Good approach. Confirm that shape assumptions always match the function logic.


43-43: Inclusion of bias in kh
Adding + hb modifies the state. This is functionally correct for the new bias dimension but watch out for potential dimension mismatches if shapes ever differ.
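To make the shape bookkeeping concrete, here is a toy recurrence rather than the actual ttt_linear math: the bias state is initialized from the provided value when given, enters the k @ h read (the "+ hb" noted above), and is returned alongside the final state. All shapes and update rules in this snippet are simplified assumptions.

import torch

def toy_bias_state_step(k, v, h0=None, hb0=None):
    B, H, T, DK = k.shape
    DV = v.shape[-1]
    h = torch.zeros(B, H, DK, DV) if h0 is None else h0.clone()
    hb = torch.zeros(B, H, 1, DV) if hb0 is None else hb0.clone()
    kh = k @ h + hb                        # bias added to the state read
    err = v - kh
    h = h + k.transpose(-1, -2) @ err      # simplified state update
    hb = hb + err.mean(-2, keepdim=True)   # simplified bias update
    return kh, h, hb

kh, h, hb = toy_bias_state_step(torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 32))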


69-69: Returning hb only if output_final_state is True
This matches the handling of h. Maintains consistent design for optional final states.


84-84: Added initial_state_bias to chunk_ttt_linear_ref
This ensures that the reference path matches the main function’s new behavior.


108-108: Casting parameters to float32
Ensures numeric stability for computations. This is consistent with the rest of the code.


109-122: Calling ttt_linear with new bias argument
Fully integrates the new parameter. The code flows logically.


126-126: Returning the new final_state_bias
Completes the reference function’s alignment with the new interface. Looks good.


19-19:

❓ Verification inconclusive

Added initial_state_bias parameter to function signature
This implements an extra bias term. Ensure all callers handle this new parameter properly.


🏁 Script executed:

#!/bin/bash
rg -A 4 "ttt_linear\(.*initial_state_bias"

Length of output: 44


Manual Verification Needed – Update Call Sites for initial_state_bias Parameter

It appears that the automated search did not return any occurrences of ttt_linear( with the new initial_state_bias parameter. Please review all call sites of ttt_linear and confirm they have been updated to pass the new initial_state_bias argument.

  • File Under Review: fla/ops/ttt/naive.py (line 19)
  • Action: Confirm every invocation of ttt_linear has been updated to include the new initial_state_bias parameter.
fla/ops/ttt/chunk.py (27)

18-19: Heuristic for USE_INITIAL_STATE_B
Recognizing hb0 as a separate initial bias is a clean approach. This addition is consistent with the existing pattern for initial_state.
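For readers unfamiliar with the pattern, this is roughly how such a flag is declared; the kernel name, body, and argument list here are placeholders, not the actual fla kernel.

import triton
import triton.language as tl

@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
})
@triton.jit
def demo_kernel(h0, hb0,
                USE_INITIAL_STATE: tl.constexpr,
                USE_INITIAL_STATE_B: tl.constexpr):
    pass  # real kernels branch on these flags to decide whether to load h0/hb0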


27-28: Auto-tune key updated to include BK, BV
Ensures Triton can optimize for dimension sizes relevant to K/V. No issues noted.


39-40: New argument hb in chunk_ttt_linear_fwd_kernel_h
Maintains parity with the introduction of hb0. This is a necessary extension to pass bias state around.


41-41: Passing hb0 to kernel
Completes the forward kernel’s ability to use an initial bias. Implementation is consistent with the new feature.


43-43: New hbt for storing final bias state
Mirrors the approach used for ht. This ensures that the final bias state is saved if needed.


75-76: Loading initial hb0 into b_hb
Essential logic for bias initialization. No concerns about shape usage here, but remain mindful of dimension alignment in future changes.


79-80: Conditional load of hb0
Ensures the kernel only uses hb0 if it’s present. Code remains safe for the no-bias scenario.


94-95: Storing updated h/hb blocks
Committing the partial chunk states here is consistent with the chunk-based approach.


109-111: Adding b_hb to b_kh
Incorporates the bias state in the matrix multiply result. This is the core of the new bias logic.


124-125: Subtracting bias portion in final state updates
b_hb -= tl.sum(...) is the correct approach for the bias’s backward-like update.


188-189: Added hb pointer in chunk_ttt_linear_fwd_kernel_o
Plumbs the bias pointer through the second forward kernel. Implementation is aligned with the overall approach.


226-230: Heuristics require adjustments for new bias parameters
Expanding USE_INITIAL_STATE_B is logical. The new code is consistent so far.


441-442: Grad for final bias state
dhbt is processed similarly to dht. This is correct for ensuring the gradient flows for the new bias dimension.


465-466: Consistent condition checks for bias usage
Double-check the logic to ensure it’s skipping or including bias gradient only when needed. Looks correct.


483-484: b_dhb used to accumulate partial bias gradients
This approach complements b_dh for the main state. No immediate concerns.


542-543: Properly subtracting partial bias gradients
Ensures that b_dv_new is updated with bias-specific details. Good consistency with earlier code.


576-578: Ensuring shape constraints with conditional
Truncation or zeroing out with tl.where is standard for dimension alignment in Triton kernels.


647-649: Inclusion of dhb pointer in the chunk offsets
We can confirm it’s updated in parallel with dh. This is a natural extension for the bias’s chunk-by-chunk updates.


691-694: Using b_e_last to manage bias updates
Implementation matches the approach for h. The logic is consistent in the backward path, though it’s quite intricate. Good job.


852-864: New function chunk_ttt_linear_bwd_h with bias
This function comprehensively extends the backward logic to handle the bias state. Keep verifying shapes in tests.


985-988: hb0 included in chunk_ttt_linear_bwd_norm signature
Ensures the backward pass can handle the bias initial state. Implementation looks correct.


990-991: Including dhbt
Extends final-state gradient to the bias. This is coherent with the rest of the backward changes.


1021-1022: Allocating dhb similarly to dh
Ensures each chunk can store the bias-state gradient. Good approach.


1025-1026: Allocating dhb0 for the initial bias gradient
Completes the symmetrical design for bias handling.


1335-1357: Combining dv_new, do, and new backward calls
The multi-step backward logic is well-structured. No major issues identified.


1359-1374: Merging dk results with dk2
Collects partial gradients from the final kernel pass, consistent with the existing pattern for main states.


1417-1418: Expanded backward signature for new bias gradients
Having dhbt in the signature ensures all relevant gradients are captured. Implementation is aligned with new features.

fla/ops/ttt/fused_chunk.py (3)

14-19: New heuristics look good.

The addition of flags like USE_INITIAL_STATE_B and USE_OFFSETS is well-structured and aligns with the extended functionality. No issues found here.


729-745: Verify documentation consistency for new parameters.

fused_chunk_ttt_linear_fwd now supports output_final_state, initial_state_bias and offsets. Ensure the docstring correctly reflects these parameters and their usage in variable-length scenarios.

Would you like a script to scan references to these parameters elsewhere in the codebase and confirm consistent usage?


894-895: Confirm norm+residual usage.

The call to norm_residual immediately after the fused forward pass modifies the output. Double-check that applying normalization and residual in this manner matches the intended architecture, especially if you plan to chain multiple fused operations.

sustcsonglin merged commit 71fe07f into fla-org:main on Mar 15, 2025
1 of 2 checks passed