[TTT] Update fused chunk ops and state bias term #230
Conversation
Walkthrough
The changes enhance the TTT linear operations by adding a new fused function and incorporating an additional bias parameter for the initial state in multiple modules. The updates span benchmark scripts, core operation files, and tests. Function signatures in the forward and backward kernels across the chunk, fused_chunk, and naive implementations now require extra parameters (such as the state-bias tensors hb0 and hbt).
Sequence Diagram(s)
sequenceDiagram
    participant U as User
    participant F as FusedChunkTTTLinear
    participant FK as FusedChunkTTTLinearFwdKernel
    participant BK as FusedChunkTTTLinearBwdKernel
    U->>+F: Call fused_chunk_ttt_linear(q, k, v, w, b, eta, ...)
    F->>+FK: Execute forward pass with (h0, hb0, ht, hbt, offsets, ...)
    FK-->>-F: Return computed output & final state bias
    alt Backward Pass
        F->>+BK: Execute backward kernels to compute gradients
        BK-->>-F: Return gradients (including dhb0)
    end
    F-->>U: Return final outputs and gradients (if computed)
sequenceDiagram
    participant B as BenchmarkTest
    participant C as ChunkTTTMethod
    participant F as FusedChunkTTTMethod
    B->>+C: Call chunk_ttt_linear with chunk_size=16
    C-->>-B: Return output and timing ("time_b")
    B->>+F: Call fused_chunk_ttt_linear with updated parameters
    F-->>-B: Return output and timing ("time_f_b")
    B-->>B: Compare benchmark results
Actionable comments posted: 0
🧹 Nitpick comments (17)
benchmarks/ops/benchmark_ttt.py (2)
35-36: Keep consistent dimension options for headdim_vals
Removing 128 from headdim_vals aligns the benchmark with a single dimension size. If you plan to benchmark higher dimensions in the future, consider re-adding them or moving them to a config-driven approach.
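A minimal sketch of what a config-driven variant could look like; the dict name and fields are hypothetical, not the benchmark's current structure:

```python
import torch

# Hypothetical config block; re-adding 128 later becomes a one-line change.
BENCH_CONFIG = {
    "headdim_vals": [64],     # e.g. append 128 to benchmark larger heads
    "dim": 2048,
    "dtype": torch.bfloat16,
}

for headdim in BENCH_CONFIG["headdim_vals"]:
    # A typical derivation of the head count; the actual script may differ.
    nheads = BENCH_CONFIG["dim"] // headdim
```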
88-99: Benchmark coverage for fused_chunk_ttt_linear looks solid
This block properly tests forward/backward performance for the new fused function. However, the pipeline log shows line-length (E501) issues around these lines (e.g., line 90). Consider wrapping or splitting the lines to adhere to style guidelines:
- k = torch.nn.functional.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
+ k_t = torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype)
+ k = torch.nn.functional.normalize(k_t, p=2, dim=-1).requires_grad_(True)
🧰 Tools
🪛 GitHub Actions: pr
[error] 90-90: flake8: E501 line too long (143 > 127 characters)
tests/ops/test_ttt.py (6)
84-86: Multiple sums in backward pass
Combining multiple sums into a single backward call is fine. Watch out for flake8 line-length or spacing issues (see pipeline logs); a standalone illustration of the pattern follows the tool output below.
🧰 Tools
🪛 GitHub Actions: pr
[error] 84-84: flake8: E501 line too long (141 > 127 characters)
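A minimal standalone illustration (not the test code itself) of folding several sums into one backward call:

```python
import torch

# Two leaf tensors standing in for the test's outputs and final states.
a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)

out1, out2 = a * 2, b * 3

# One backward over the combined sum populates both a.grad and b.grad,
# equivalent to calling backward on each sum separately.
(out1.sum() + out2.sum()).backward()
assert a.grad is not None and b.grad is not None
```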
88-99: Reference implementation chunk_ttt_linear_ref updated with bias
A balanced approach to testing the reference against the new function. Also note potential line-length warnings (E501).
122-122: Potential duplication?
Check whether lines 121-122 duplicate the same test parametrization. If a separate test is not needed, consider deduplicating.
125-125: Single scale value usage
Defining [0.1] is fine, but you could parametrize multiple scale values for broader coverage, as sketched below.
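A hedged sketch of how the scale could be parametrized; the test name and values are illustrative:

```python
import pytest

@pytest.mark.parametrize("scale", [0.1, 0.5, 1.0])
def test_scale_coverage(scale):
    # Placeholder body; the real test would forward `scale` into chunk_ttt_linear.
    assert scale > 0
```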
157-161: Mapping everything to .cuda().requires_grad_(True) again
Uniform usage across code. Keep an eye on flake8 line length if these lines get longer.
162-210: Testing fused_chunk_ttt_linear with new bias parameter
The block thoroughly validates forward/backward behavior for the fused approach. Pipeline logs mention line-length warnings in this region; consider splitting lines to comply with style.
🧰 Tools
🪛 GitHub Actions: pr
[error] 174-174: flake8: E222 multiple spaces after operator
[error] 175-175: flake8: E501 line too long (141 > 127 characters)
[error] 191-191: flake8: E222 multiple spaces after operator
[error] 192-192: flake8: E222 multiple spaces after operator
[error] 204-204: flake8: E231 missing whitespace after ','
[error] 206-206: flake8: E231 missing whitespace after ','
fla/ops/ttt/chunk.py (4)
97-105: HEAD_FIRST vs. non-HEAD_FIRST paths
The logic matches how offsets/chunking are handled throughout. Just keep an eye on line length if these lines expand further.
239-340: Backward kernel extended to handle hb/hb0
A large chunk of changes. The approach carefully mirrors the forward logic. The final updates to b_hb are consistent. Watch out for the flake8 warning flagged around line 332 (E127, continuation line over-indented).
🧰 Tools
🪛 GitHub Actions: pr
[error] 332-332: flake8: E127 continuation line over-indented for visual indent
897-923: Detailed backward kernel call for partial results
Storing v_new, x, y, rstd is consistent with the new design. Potential line-length issues remain, but the logic is fine.
1382-1382: Exceeding line length (E501) in forward function signature
Consider wrapping the arguments to fix the flake8 error:
- def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state, initial_state_bias, output_final_state, offsets, head_first):
+ def forward(
+     ctx, q, k, v, w, b, BT, eta, scale, eps,
+     initial_state, initial_state_bias,
+     output_final_state, offsets, head_first
+ ):
🧰 Tools
🪛 GitHub Actions: pr
[error] 1382-1382: flake8: E501 line too long (133 > 127 characters)
fla/ops/ttt/fused_chunk.py (5)
4-4: Remove the unused import.
The static analysis tools flag typing.Tuple as unused. Removing it will address the flake8 error (F401):
- from typing import Optional, Tuple
+ from typing import Optional
🧰 Tools
🪛 Ruff (0.8.2)
4-4: typing.Tuple imported but unused
Remove unused import: typing.Tuple (F401)
🪛 GitHub Actions: pr
[error] 4-4: flake8: F401 'typing.Tuple' imported but unused
192-192: Remove the unused variable.
eos is assigned but never used. Removing this assignment will resolve the flake8 error (F841):
- bos, eos = i_n * T, i_n * T + T
+ bos = i_n * T
🧰 Tools
🪛 Ruff (0.8.2)
192-192: Local variable eos is assigned to but never used
Remove assignment to unused variable eos (F841)
🪛 GitHub Actions: pr
[error] 192-192: flake8: F841 local variable 'eos' is assigned to but never used
335-335: Remove the unused variable again.
Here as well, eos is never referenced after assignment. Removing it prevents confusion and addresses the static analysis warning:
- bos, eos = i_n * T, i_n * T + T
+ bos = i_n * T
🧰 Tools
🪛 Ruff (0.8.2)
335-335: Local variable eos is assigned to but never used
Remove assignment to unused variable eos (F841)
757-778: Consider refactoring repeated backward logic.
The backward pass code is quite extensive. If there is repeated logic between fused_chunk_ttt_linear_bwd_h and fused_chunk_ttt_linear_bwd_dh, consider extracting it into shared helper functions to improve maintainability and reduce duplication, e.g. as sketched below.
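A rough sketch of the kind of shared helper the two backward entry points could call; the name, layout assumption, and return values are illustrative, not the actual fla API:

```python
import torch

def _bwd_launch_params(q: torch.Tensor, chunk_size: int):
    """Derive shapes and chunk counts once, instead of repeating them per kernel."""
    # Assumes a head-first (B, H, T, K) layout; the real code may branch on head_first.
    B, H, T, K = q.shape
    NT = (T + chunk_size - 1) // chunk_size  # number of chunks per sequence
    return B, H, T, K, NT

# Usage: both backward helpers would consume the same tuple.
B, H, T, K, NT = _bwd_launch_params(torch.empty(2, 4, 64, 32), chunk_size=16)
```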
865-871: Evaluate variable-length support in head-first mode.
Currently, an exception is raised when head_first is True and cu_seqlens is provided. If you plan to support variable-length inputs for head-first mode in the future, consider consolidating or documenting the differences needed to enable it.
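A small sketch of the guard described above; the helper name and message wording are illustrative, and documenting the constraint this way makes the unsupported combination explicit:

```python
def _check_varlen_args(head_first: bool, cu_seqlens) -> None:
    # Hypothetical helper: variable-length inputs are only accepted when the
    # tensors are laid out as (B, T, H, D), i.e. head_first=False.
    if head_first and cu_seqlens is not None:
        raise ValueError(
            "cu_seqlens is not supported with head_first=True; "
            "pass inputs in (B, T, H, D) layout instead."
        )
```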
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- benchmarks/ops/benchmark_ttt.py (3 hunks)
- fla/ops/ttt/chunk.py (45 hunks)
- fla/ops/ttt/fused_chunk.py (3 hunks)
- fla/ops/ttt/naive.py (6 hunks)
- tests/ops/test_ttt.py (13 hunks)
🧰 Additional context used
🪛 GitHub Actions: pr
benchmarks/ops/benchmark_ttt.py
[error] 10-10: flake8: F401 'fla.ops.ttt.naive.chunk_ttt_linear_ref' imported but unused
[error] 24-24: flake8: E302 expected 2 blank lines, found 1
[error] 66-66: flake8: E501 line too long (143 > 127 characters)
[error] 77-77: flake8: E501 line too long (143 > 127 characters)
[error] 90-90: flake8: E501 line too long (143 > 127 characters)
tests/ops/test_ttt.py
[error] 83-83: flake8: E222 multiple spaces after operator
[error] 84-84: flake8: E501 line too long (141 > 127 characters)
[error] 100-100: flake8: E222 multiple spaces after operator
[error] 101-101: flake8: E222 multiple spaces after operator
[error] 101-101: flake8: E501 line too long (142 > 127 characters)
[error] 113-113: flake8: E231 missing whitespace after ','
[error] 115-115: flake8: E231 missing whitespace after ','
[error] 174-174: flake8: E222 multiple spaces after operator
[error] 175-175: flake8: E501 line too long (141 > 127 characters)
[error] 191-191: flake8: E222 multiple spaces after operator
[error] 192-192: flake8: E222 multiple spaces after operator
[error] 204-204: flake8: E231 missing whitespace after ','
[error] 206-206: flake8: E231 missing whitespace after ','
fla/ops/ttt/chunk.py
[error] 121-121: flake8: E127 continuation line over-indented for visual indent
[error] 332-332: flake8: E127 continuation line over-indented for visual indent
[error] 1382-1382: flake8: E501 line too long (133 > 127 characters)
fla/ops/ttt/fused_chunk.py
[error] 4-4: flake8: F401 'typing.Tuple' imported but unused
[error] 192-192: flake8: F841 local variable 'eos' is assigned to but never used
🪛 Ruff (0.8.2)
fla/ops/ttt/fused_chunk.py
4-4: typing.Tuple imported but unused
Remove unused import: typing.Tuple
(F401)
192-192: Local variable eos is assigned to but never used
Remove assignment to unused variable eos
(F841)
335-335: Local variable eos is assigned to but never used
Remove assignment to unused variable eos
(F841)
🔇 Additional comments (58)
benchmarks/ops/benchmark_ttt.py (3)
9-9: Import for fused_chunk_ttt_linear is properly utilized
This addition is coherent with the subsequent usage of fused_chunk_ttt_linear.
41-41: Expanded methods list
Including"fused_chunk_ttt_linear"ensures coverage of the new fused operation in benchmarks. No issues detected.
81-85: Chunk size parameter rename is consistent
Switching from BT to chunk_size improves clarity. Ensure the same naming convention is applied consistently throughout the codebase.
tests/ops/test_ttt.py (17)
31-31: Added 16 to the list of T values for test parametrization
It’s good to test smaller sequences. Keep an eye on potential edge cases at T=16.
55-55: New hb0 definition
Storing the bias in hb0 ensures distinct coverage for the initial state bias. No immediate issues here.
64-64: Repeated definition of hb0
It is consistent with the scenario where head_first=False. The approach mirrors line 55. No concerns.
66-66: Consolidated .requires_grad_(True) calls
The map(lambda x: x.cuda().requires_grad_(True), ...) pattern is consistent across parameters, simplifying the code.
69-69: Added dhbt for bias gradient testing
Ensures the new bias term is fully tested in backward passes.
71-82: Valid incorporation of initial_state_bias into chunk_ttt_linear
Properly passing initial_state_bias=hb0.clone() ensures the test covers the new bias functionality. The code changes appear correct.
101-102: Backward pass for reference output
The usage of an additional backward call for ref ensures parity with the new function. Looks correct.
🧰 Tools
🪛 GitHub Actions: pr
[error] 101-101: flake8: E222 multiple spaces after operator
[error] 101-101: flake8: E501 line too long (142 > 127 characters)
104-118: Extended output and gradient assertions
Testing ref vs. tri for all components, including hbt, is crucial for verifying the correctness of the new bias term. This thoroughness is commendable.
🧰 Tools
🪛 GitHub Actions: pr
[error] 113-113: flake8: E231 missing whitespace after ','
[error] 115-115: flake8: E231 missing whitespace after ','
121-121: Parametrization repeated for T
Including [16, 30, 32, 63, 64, 256] again in another test ensures comprehensive coverage.
146-147: Preparing hb0 for the fused_chunk test
Same pattern as the chunk test. The approach is consistent. No issues.
213-216: Additional test parameters for chunk_varlen_fwd
Expanding coverage for a wide T range. Looks consistent with the other tests.
243-244: Adding hb0 in the variable-length scenario
Ensures the new bias dimension is tested for chunked variable lengths. Good coverage.
246-246: Integrating bias in chunk_ttt_linear call
Ensuring the function is tested under real conditions. No immediate issues.
263-281: Reference chunk_ttt_linear_ref logic for varlen
Maintaining the same logic for reference ensures consistent testing. No functional concerns.
283-283: Concatenating ref_hbt
Combining bias states from each chunk is correct for final comparative checks.
286-287: Asserting final states for varlen scenario
Ensures new bias dimension is validated end-to-end. No concerns.
3-3: ✅ Verification successful
Ensure the os import is necessary
You introduced import os. Confirm you are actually using it in this file to avoid unused imports.
🏁 Script executed:
#!/bin/bash
rg --context 3 "os\." tests/ops/test_ttt.py
Length of output: 265
Verified: the os import is necessary
The os module is used within the file, specifically in the line setting os.environ['TRITON_F32_DEFAULT'] = 'ieee', so the import should be retained.
- Verified usage in tests/ops/test_ttt.py confirming that os.environ is used.
fla/ops/ttt/naive.py (8)
36-36: Conditional initialization of hb
Using the provided bias if available is valid. Good approach. Confirm that shape assumptions always match the function logic.
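A minimal sketch of the pattern, assuming the bias state has shape (B, H, 1, V); the helper name and shape are assumptions, not taken from the diff:

```python
import torch

def init_state_bias(initial_state_bias, B, H, V, dtype=torch.float32):
    # Use the provided bias when given, otherwise start from zeros.
    if initial_state_bias is not None:
        return initial_state_bias.to(dtype)
    return torch.zeros(B, H, 1, V, dtype=dtype)
```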
43-43: Inclusion of bias in kh
Adding + hb modifies the state update. This is functionally correct for the new bias dimension, but watch out for potential dimension mismatches if shapes ever differ.
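The broadcast the comment alludes to, shown with assumed shapes (k: (B, H, T, K), h: (B, H, K, V), hb: (B, H, 1, V)); a shape mismatch in hb would surface here:

```python
import torch

B, H, T, K, V = 2, 4, 64, 32, 32
k = torch.randn(B, H, T, K)
h = torch.randn(B, H, K, V)
hb = torch.randn(B, H, 1, V)

# hb broadcasts over the T dimension of the batched matmul result.
kh = k @ h + hb
assert kh.shape == (B, H, T, V)
```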
69-69: Returning hb only if output_final_state is True
This matches the handling of h and maintains a consistent design for optional final states.
84-84: Added initial_state_bias to chunk_ttt_linear_ref
This ensures that the reference path matches the main function’s new behavior.
108-108: Casting parameters to float32
Ensures numeric stability for computations. This is consistent with the rest of the code.
109-122: Calling ttt_linear with the new bias argument
Fully integrates the new parameter. The code flows logically.
126-126: Returning the new final_state_bias
Completes the reference function’s alignment with the new interface. Looks good.
19-19: ❓ Verification inconclusive
Added initial_state_bias parameter to the function signature
This implements an extra bias term. Ensure all callers handle this new parameter properly.
🏁 Script executed:
#!/bin/bash
rg -A 4 "ttt_linear\(.*initial_state_bias"
Length of output: 44
Manual verification needed: update call sites for the initial_state_bias parameter
It appears that the automated search did not return any occurrences of ttt_linear( with the new initial_state_bias parameter. Please review all call sites of ttt_linear to ensure they are updated to pass the initial_state_bias argument.
- File under review: fla/ops/ttt/naive.py (line 19)
- Action: Confirm every invocation of ttt_linear has been updated to include the new initial_state_bias parameter.
fla/ops/ttt/chunk.py (27)
18-19: Heuristic for USE_INITIAL_STATE_B
Recognizing hb0 as a separate initial bias is a clean approach. This addition is consistent with the existing pattern for initial_state.
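A self-contained sketch (not the fla kernel) of how such a heuristic is typically wired up in Triton: the flag is derived from whether the optional bias argument was passed, and the load is skipped otherwise. The kernel and argument names here are placeholders:

```python
import torch
import triton
import triton.language as tl

@triton.heuristics({'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None})
@triton.jit
def copy_bias_kernel(hb0, out, V: tl.constexpr, USE_INITIAL_STATE_B: tl.constexpr):
    o_v = tl.arange(0, V)
    if USE_INITIAL_STATE_B:
        b_hb = tl.load(hb0 + o_v).to(tl.float32)
    else:
        # In the real kernels hb0 may be None on this path; it is never
        # dereferenced because the heuristic flag guards the load.
        b_hb = tl.zeros([V], dtype=tl.float32)
    tl.store(out + o_v, b_hb)

if torch.cuda.is_available():
    hb0 = torch.randn(16, device='cuda')
    out = torch.empty_like(hb0)
    copy_bias_kernel[(1,)](hb0, out, V=16)
```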
27-28: Auto-tune key updated to include BK, BV
Ensures Triton can optimize for dimension sizes relevant to K/V. No issues noted.
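Illustrative only: with the block sizes in the key, cached tuning results are reused only for matching BK/BV combinations. The configs and kernel below are placeholders, not the fla kernel:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in (2, 4, 8)],
    key=['BK', 'BV'],  # retune whenever either block size changes
)
@triton.jit
def autotuned_kernel_sketch(x, y, BK: tl.constexpr, BV: tl.constexpr):
    # Trivial body: copy BK elements; BV is kept only to show it in the key.
    o_k = tl.arange(0, BK)
    tl.store(y + o_k, tl.load(x + o_k))
```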
39-40: New argument hb in chunk_ttt_linear_fwd_kernel_h
Maintains parity with the introduction of hb0. This is a necessary extension to pass the bias state around.
41-41: Passing hb0 to the kernel
Completes the forward kernel's ability to use an initial bias. Implementation is consistent with the new feature.
43-43: New hbt for storing the final bias state
Mirrors the approach used for ht. This ensures that the final bias state is saved if needed.
75-76: Loading initial hb0 into b_hb
Essential logic for bias initialization. No concerns about shape usage here, but remain mindful of dimension alignment in future changes.
79-80: Conditional load of hb0
Ensures the kernel only uses hb0 if it is present. The code remains safe for the no-bias scenario.
94-95: Storing updated h/hb blocks
Committing the partial chunk states here is consistent with the chunk-based approach.
109-111: Adding b_hb to b_kh
Incorporates the bias state in the matrix multiply result. This is the core of the new bias logic.
124-125: Subtracting bias portion in final state updates
b_hb -= tl.sum(...) is the correct approach for the bias's backward-like update.
188-189: Added hb pointer in chunk_ttt_linear_fwd_kernel_o
Plumbs the bias pointer through the second forward kernel. Implementation is aligned with the overall approach.
226-230: Heuristics require adjustments for new bias parameters
Expanding USE_INITIAL_STATE_B is logical. The new code is consistent so far.
441-442: Grad for final bias state
dhbt is processed similarly to dht. This is correct for ensuring the gradient flows for the new bias dimension.
465-466: Consistent condition checks for bias usage
Double-check the logic to ensure the bias gradient is skipped or included only when needed. Looks correct.
483-484: b_dhb used to accumulate partial bias gradients
This approach complements b_dh for the main state. No immediate concerns.
542-543: Properly subtracting partial bias gradients
Ensures that b_dv_new is updated with the bias-specific details. Good consistency with earlier code.
576-578: Ensuring shape constraints with conditional
Truncation or zeroing out with tl.where is standard for dimension alignment in Triton kernels.
647-649: Inclusion of dhb pointer in the chunk offsets
We can confirm it is updated in parallel with dh. This is a natural extension for the bias's chunk-by-chunk updates.
691-694: Using b_e_last to manage bias updates
The implementation matches the approach for h. The logic is consistent in the backward path, though it is quite intricate. Good job.
852-864: New function chunk_ttt_linear_bwd_h with bias
This function comprehensively extends the backward logic to handle the bias state. Keep verifying shapes in tests.
985-988: hb0 included in the chunk_ttt_linear_bwd_norm signature
Ensures the backward pass can handle the bias initial state. Implementation looks correct.
990-991: Including dhbt
Extends final-state gradient to the bias. This is coherent with the rest of the backward changes.
1021-1022: Allocating dhb similarly to dh
Ensures each chunk can store the bias-state gradient. Good approach.
1025-1026: Allocating dhb0 for the initial bias gradient
Completes the symmetrical design for bias handling.
1335-1357: Combining dv_new, do, and new backward calls
The multi-step backward logic is well-structured. No major issues identified.
1359-1374: Merging dk results with dk2
Collects partial gradients from the final kernel pass, consistent with the existing pattern for main states.
1417-1418: Expanded backward signature for new bias gradients
Having dhbt in the signature ensures all relevant gradients are captured. Implementation is aligned with the new features.
fla/ops/ttt/fused_chunk.py (3)
14-19: New heuristics look good.
The addition of flags like USE_INITIAL_STATE_B and USE_OFFSETS is well-structured and aligns with the extended functionality. No issues found here.
729-745: Verify documentation consistency for new parameters.
fused_chunk_ttt_linear_fwd now supports output_final_state, initial_state_bias, and offsets. Ensure the docstring correctly reflects these parameters and their usage in variable-length scenarios.
Would you like a script to scan references to these parameters elsewhere in the codebase and confirm consistent usage?
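A hedged docstring sketch only; the argument names follow the discussion above, but the described shapes are assumptions, not the confirmed fla API:

```python
def fused_chunk_ttt_linear_fwd_docstring_sketch(
    initial_state_bias=None, output_final_state=False, offsets=None
):
    """
    Args:
        initial_state_bias:
            Optional bias of the recurrent state, used alongside initial_state.
        output_final_state:
            If True, also return the final state and the final state bias.
        offsets:
            Cumulative sequence lengths for variable-length (varlen) inputs.
    """
```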
894-895: Confirm norm+residual usage.
The call to norm_residual immediately after the fused forward pass modifies the output. Double-check that applying normalization and the residual in this manner matches the intended architecture, especially if you plan to chain multiple fused operations.