Conversation

@yuweih205 (Contributor) commented Jun 26, 2025

This PR updates 10 benchmark scripts to make them compatible with the latest interface of the FLA project.

✅ Changes:

  • Updated usage of triton.testing.do_bench() in:
    benchmark_abc.py, benchmark_based.py, benchmark_delta_rule.py, benchmark_fla.py, benchmark_gla.py, benchmark_gsa.py, benchmark_retention.py, benchmark_titans.py, benchmark_ttt.py, benchmark_nsa.py
  • Fixed argument mismatches and keyword errors in provider logic
  • Removed legacy benchmark formats incompatible with latest fla.ops.*

All modified benchmarks are verified to run successfully on the latest version of FLA + Triton.
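For context, the do_bench adaptation is of this flavor; a minimal sketch assuming the current triton.testing.do_bench signature (the exact per-script edits are in the diff):

import torch
import triton

def bench(fn):
    # Current triton.testing.do_bench accepts `quantiles` and, when set,
    # returns the corresponding timings in milliseconds: here the median,
    # then the 20th and 80th percentiles.
    return triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')
ms, min_ms, max_ms = bench(lambda: a @ b)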

Note:

  • During benchmarking, fused_recurrent_gla and fused_recurrent_gsa encounter segmentation faults in their backward passes.
  • These faults occur consistently at the kernel-execution stage and are reproducible even without the benchmark wrappers; a minimal repro sketch follows this list.
  • This points to an issue in their internal implementations rather than in the benchmarking logic.
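Repro sketch of the failing backward; the shapes, the gk gate-argument name (taken from the tests), and the import path are assumptions for illustration:

import torch
import torch.nn.functional as F
from fla.ops.gla import fused_recurrent_gla  # import path assumed

B, T, H, D = 8, 1024, 4, 128
q = torch.randn(B, T, H, D, device='cuda', requires_grad=True)
k = torch.randn(B, T, H, D, device='cuda', requires_grad=True)
v = torch.randn(B, T, H, D, device='cuda', requires_grad=True)
gk = F.logsigmoid(torch.randn(B, T, H, D, device='cuda')).requires_grad_()

o = fused_recurrent_gla(q, k, v, gk)[0]  # first element of the returned tuple
o.backward(torch.ones_like(o))           # segfaults at kernel execution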

This PR restores 10 benchmark scripts to working order on the latest flash-linear-attention. The changes cover interface adaptation and argument-compatibility handling, and have been verified on an H100.

Summary by CodeRabbit

  • Refactor

    • Standardized tensor shape ordering across multiple benchmark scripts to consistently use (batch, sequence, head, dimension) format for input tensors.
    • Updated auxiliary tensor shapes and argument passing to align with the new tensor ordering.
    • Adjusted specific benchmarks to restrict tested methods or update parameter shapes for improved consistency.
    • Improved clarity in argument passing by using keyword arguments in some benchmark calls.
  • Chores

    • Disabled or commented out certain benchmark cases and methods to streamline testing and output.

@coderabbitai bot (Contributor) commented Jun 26, 2025

Warning

Rate limit exceeded

@yzhangcs has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 23 minutes and 34 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.


📥 Commits

Reviewing files that changed from the base of the PR and between dfd22a2 and 7e1c78f.

📒 Files selected for processing (3)
  • benchmarks/ops/benchmark_fla.py (2 hunks)
  • benchmarks/ops/benchmark_gla.py (2 hunks)
  • benchmarks/ops/benchmark_gsa.py (1 hunks)

"""

Walkthrough

This change standardizes tensor dimension orderings across multiple benchmarking scripts for various attention and linear operations. Tensor initializations for queries, keys, values, and auxiliary variables are updated to consistently use the (B, T, H, D) shape, replacing previous variations. Some benchmarked methods are removed or commented out, and keyword argument usage is clarified for specific function calls.
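Concretely, the reordering amounts to the following (schematic; sizes and dtype are illustrative):

import torch

B, T, H, D = 8, 2048, 4, 128
# previous head-major layout used by several of the scripts
q_old = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
# standardized time-major layout now used across the benchmarks
q_new = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16)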

Changes

File(s) / Change Summary
  • benchmarks/ops/benchmark_abc.py, benchmark_retention.py, benchmark_gsa.py: Standardized tensor shapes for q, k, v, and auxiliary tensors to (B, T, H, D) across all providers.
  • benchmarks/ops/benchmark_based.py: Refined provider-based branching for tensor shape selection; explicit shape control for each provider.
  • benchmarks/ops/benchmark_delta_rule.py, benchmark_ttt.py: Changed tensor shape ordering from (B, H, T, D) to (B, T, H, D) for q, k, v, beta, eta, and related tensors.
  • benchmarks/ops/benchmark_fla.py: Unified tensor shapes across providers; corrected backward calls to use the correct tensor outputs.
  • benchmarks/ops/benchmark_gla.py: Conditional tensor shape initialization based on provider; minor changes to backward call arguments.
  • benchmarks/ops/benchmark_nsa.py: Changed function calls to use the keyword argument block_indices=indices instead of positional.
  • benchmarks/ops/benchmark_titans.py: Updated w and b tensor shapes from (H, headdim) to (seqlen, headdim) for the Titans linear operation.

Sequence Diagram(s)

sequenceDiagram
    participant BenchmarkScript
    participant TensorInit
    participant ProviderFunction

    BenchmarkScript->>TensorInit: Initialize q, k, v (B, T, H, D)
    alt Provider requires auxiliary tensor
        TensorInit->>TensorInit: Initialize g, s, f, beta, etc. (B, T, H, D/M)
    end
    BenchmarkScript->>ProviderFunction: Call benchmarked function with tensors
    ProviderFunction-->>BenchmarkScript: Return result(s)
    BenchmarkScript->>ProviderFunction: Call backward if needed

Poem

In the land of tensors, shapes align,
(B, T, H, D) now by design.
Benchmarks run with order true,
Each provider knows just what to do.
Rabbits cheer as chaos ends,
For tidy shapes are coding friends!
🐇✨
"""

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate Unit Tests
  • Create PR with Unit Tests
  • Post Copyable Unit Tests in Comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai auto-generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🔭 Outside diff range comments (2)
benchmarks/ops/benchmark_delta_rule.py (1)

8-8: Remove unused import.

The fused_chunk_delta_rule import is no longer used since the method was removed from benchmarking.

-from fla.ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule
+from fla.ops.delta_rule import chunk_delta_rule
tests/ops/test_gsa.py (1)

97-97: Fix gradient assignment bug.

There's an incorrect variable assignment where s.grad is used instead of g.grad.

-    tri_dg, s.grad = g.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None
🧹 Nitpick comments (3)
.github/workflows/triton-builder.yml (1)

29-34: Consider adding error handling for API rate limits.

The GitHub API call for fetching the latest tag could potentially hit rate limits or fail.

Consider adding retry logic or fallback mechanisms:

-          latest_tag=$(curl -sL -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GH_TOKEN" https://api.github.com/repos/triton-lang/triton/tags | jq -r '.[0].name')
+          latest_tag=$(curl -sL --retry 3 --retry-delay 2 -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GH_TOKEN" https://api.github.com/repos/triton-lang/triton/tags | jq -r '.[0].name')
.github/workflows/reusable-ci-tests.yml (1)

52-78: Consider extracting Conda discovery to a shared action.

The Conda discovery logic is duplicated between the two jobs. This could be extracted to a composite action or simplified.

Consider creating a composite action for Conda discovery:

# .github/actions/setup-conda/action.yml
name: 'Setup Conda Environment'
inputs:
  conda_env_name:
    required: true
runs:
  using: 'composite'
  steps:
    - name: Discover Conda Path and Set Env Vars
      # ... your existing logic

Then use it in both jobs:

- name: Setup Conda
  uses: ./.github/actions/setup-conda
  with:
    conda_env_name: ${{ inputs.conda_env_name }}

Also applies to: 210-231

fla/ops/common/fused_recurrent.py (1)

59-59: Rename variable to avoid shadowing Python built-in.

The variable name all shadows the Python built-in function. While this works in Triton kernels, it's better to use a more descriptive name.

-    all = B * T
+    total_elements = B * T

And update all references to this variable accordingly (lines 71, 180, 251, 252, 256, 259, 263).

Also applies to: 180-180

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f26c819 and 49a06f1.

📒 Files selected for processing (23)
  • .github/workflows/intel-a770.yml (1 hunks)
  • .github/workflows/nvidia-4090.yml (1 hunks)
  • .github/workflows/nvidia-a100.yml (1 hunks)
  • .github/workflows/nvidia-h100.yml (1 hunks)
  • .github/workflows/reusable-build-triton.yml (1 hunks)
  • .github/workflows/reusable-ci-tests.yml (1 hunks)
  • .github/workflows/triton-builder.yml (1 hunks)
  • .github/workflows/triton-nightly.yml (1 hunks)
  • benchmarks/ops/benchmark_abc.py (1 hunks)
  • benchmarks/ops/benchmark_based.py (1 hunks)
  • benchmarks/ops/benchmark_delta_rule.py (3 hunks)
  • benchmarks/ops/benchmark_fla.py (2 hunks)
  • benchmarks/ops/benchmark_gla.py (2 hunks)
  • benchmarks/ops/benchmark_gsa.py (2 hunks)
  • benchmarks/ops/benchmark_nsa.py (1 hunks)
  • benchmarks/ops/benchmark_retention.py (1 hunks)
  • benchmarks/ops/benchmark_titans.py (1 hunks)
  • benchmarks/ops/benchmark_ttt.py (2 hunks)
  • fla/ops/common/fused_recurrent.py (15 hunks)
  • fla/ops/gsa/fused_recurrent.py (5 hunks)
  • tests/ops/test_gated_delta.py (1 hunks)
  • tests/ops/test_gla.py (8 hunks)
  • tests/ops/test_gsa.py (4 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (5)
benchmarks/ops/benchmark_nsa.py (1)
fla/ops/nsa/parallel.py (1)
  • parallel_nsa (766-881)
benchmarks/ops/benchmark_abc.py (1)
fla/modules/activations.py (1)
  • logsigmoid (157-158)
benchmarks/ops/benchmark_gla.py (1)
fla/modules/activations.py (1)
  • logsigmoid (157-158)
benchmarks/ops/benchmark_gsa.py (1)
fla/modules/activations.py (1)
  • logsigmoid (157-158)
tests/ops/test_gla.py (4)
fla/ops/gla/chunk.py (2)
  • chunk_gla (1217-1300)
  • backward (1195-1213)
fla/ops/gla/fused_recurrent.py (1)
  • fused_recurrent_gla (11-111)
fla/ops/gla/naive.py (1)
  • naive_recurrent_gla (12-41)
fla/utils.py (1)
  • assert_close (78-90)
🪛 Pylint (3.3.7)
tests/ops/test_gated_delta.py

[refactor] 153-153: Too many arguments (9/5)

(R0913)


[refactor] 153-153: Too many positional arguments (9/5)

(R0917)


[refactor] 153-153: Too many local variables (19/15)

(R0914)

benchmarks/ops/benchmark_fla.py

[error] 73-73: Using variable 'g' before assignment

(E0601)

tests/ops/test_gla.py

[refactor] 109-109: Too many local variables (30/15)

(R0914)


[refactor] 275-275: Too many local variables (27/15)

(R0914)

🪛 Flake8 (7.2.0)
benchmarks/ops/benchmark_gla.py

[error] 42-42: whitespace before ':'

(E203)

🪛 GitHub Actions: lint
benchmarks/ops/benchmark_gla.py

[error] 42-42: ESLint: whitespace before ':'. (E203)

benchmarks/ops/benchmark_delta_rule.py

[error] 8-8: ESLint: 'fla.ops.delta_rule.fused_chunk_delta_rule' imported but unused. (F401)

🪛 actionlint (1.7.7)
.github/workflows/triton-builder.yml

20-20: label "x64-docker" is unknown. available labels are "windows-latest", "windows-latest-8-cores", "windows-2025", "windows-2022", "windows-2019", "ubuntu-latest", "ubuntu-latest-4-cores", "ubuntu-latest-8-cores", "ubuntu-latest-16-cores", "ubuntu-24.04", "ubuntu-24.04-arm", "ubuntu-22.04", "ubuntu-22.04-arm", "ubuntu-20.04", "macos-latest", "macos-latest-xl", "macos-latest-xlarge", "macos-latest-large", "macos-15-xlarge", "macos-15-large", "macos-15", "macos-14-xl", "macos-14-xlarge", "macos-14-large", "macos-14", "macos-13-xl", "macos-13-xlarge", "macos-13-large", "macos-13", "self-hosted", "x64", "arm", "arm64", "linux", "macos", "windows". if it is a custom label for self-hosted runner, set list of labels in actionlint.yaml config file

(runner-label)

⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (44)
benchmarks/ops/benchmark_retention.py (1)

39-41: LGTM! Tensor dimension standardization improves consistency.

The standardization of tensor shapes to (B, T, H, D) across all providers eliminates conditional complexity and aligns with the unified tensor layout conventions being adopted across the benchmark suite.

benchmarks/ops/benchmark_titans.py (1)

62-63: Verify tensor shape compatibility with chunk_titans_linear_ref.

The tensor shapes for w and b have been changed from (H, headdim) to (seqlen, headdim). Please ensure this change is compatible with the expected parameter shapes in the chunk_titans_linear_ref function.

#!/bin/bash
# Description: Verify the expected parameter shapes for chunk_titans_linear_ref

# Search for the Python function definition to check parameter expectations
ast-grep --lang python --pattern $'def chunk_titans_linear_ref($$$):
  $$$'
benchmarks/ops/benchmark_based.py (1)

50-57: LGTM! Improved provider-specific tensor shape handling.

The refined logic appropriately separates providers that require (B, H, T, D) tensor layout from those using the standardized (B, T, H, D) layout, maintaining compatibility while moving toward consistency.

.github/workflows/reusable-build-triton.yml (3)

98-101: LGTM! URL patching with verification.

The script correctly patches the LLVM mirror URL and includes verification to ensure the replacement succeeded. The fallback to different file paths (python/setup.py vs setup.py) is a good defensive practice.


156-189: Robust error handling in nightly build processing.

The nightly build wheel processing includes proper error handling with:

  • Subshell execution to isolate errors
  • Comprehensive cleanup of temporary directories
  • All-or-nothing upload strategy based on processing success

This prevents partial uploads of corrupted packages.


194-220: Secure and robust package upload process.

The upload logic includes proper safeguards:

  • Verification that wheels exist before upload
  • Individual file upload with error checking
  • Secure credential handling via secrets
  • Clear error messaging for debugging
benchmarks/ops/benchmark_nsa.py (1)

55-55: LGTM! Improved code clarity with explicit keyword arguments.

Converting from positional to keyword arguments (block_indices=indices) improves code readability and reduces the risk of argument order errors. This aligns with the function signature in fla/ops/nsa/parallel.py.

Also applies to: 60-60
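As a self-contained illustration of the pitfall being avoided, using a hypothetical stand-in whose signature is merely shaped like parallel_nsa's (the real one lives in fla/ops/nsa/parallel.py):

# Hypothetical stand-in; parameter names other than block_indices are invented.
def parallel_nsa_stub(q, k, v, g_slc=None, g_swa=None, block_indices=None):
    return block_indices

indices = [[0, 1], [2, 3]]
# positional: `indices` silently lands in g_slc, the bug class being avoided
assert parallel_nsa_stub(None, None, None, indices) is None
# keyword: the argument goes exactly where intended
assert parallel_nsa_stub(None, None, None, block_indices=indices) == indices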

tests/ops/test_gated_delta.py (1)

153-153: Function rename aligns with naming conventions.

The rename from test_recurrent_forward to test_fused_recurrent improves consistency with the broader codebase refactoring toward fused operations.

benchmarks/ops/benchmark_abc.py (2)

45-47: Tensor shape standardization improves consistency.

The change from conditional tensor shapes to a unified (B, T, H, D) format across all providers eliminates complexity and aligns with the updated FLA benchmark interface.


50-53: Auxiliary tensor shapes standardized appropriately.

The auxiliary tensors g and s are correctly updated to match the new dimension ordering (B, T, H, D) and (B, T, H, M) respectively.

benchmarks/ops/benchmark_gsa.py (3)

47-49: Tensor shape standardization matches interface updates.

The tensor shape change from (B, H, T, D) to (B, T, H, D) aligns with the FLA benchmark interface refactoring and maintains consistency across benchmark files.


51-54: Auxiliary tensor shapes updated consistently.

The shapes for tensors f and g are correctly updated to (B, T, H, M) and (B, T, H, D) respectively, maintaining consistency with the main tensor shape changes.


66-67: Verify if gsa_recurrent_bwd should be permanently removed.

The gsa_recurrent_bwd benchmark is commented out, which might be temporary for debugging or indicate it's incompatible with the new interface.

Please clarify if this benchmark should be permanently removed or if it's temporarily disabled. If permanently removed, consider updating the line_vals and line_names arrays to remove the corresponding entries.
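If it is removed for good, the perf-report config would shrink accordingly. A hypothetical sketch, with provider names and x_vals invented for illustration (the real values are in benchmark_gsa.py):

import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['T'],
        x_vals=[128 * 2 ** i for i in range(6)],
        line_arg='provider',
        # 'gsa_recurrent_bwd' dropped from BOTH lists so they stay in sync
        line_vals=['gsa_recurrent_fwd', 'gsa_chunk_fwdbwd'],
        line_names=['gsa_recurrent_fwd', 'gsa_chunk_fwdbwd'],
        ylabel='ms',
        plot_name='gsa-benchmark',
        args={},
    )
)
def benchmark(T, provider):
    raise NotImplementedError  # body elided; see benchmark_gsa.py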

benchmarks/ops/benchmark_delta_rule.py (2)

38-38: Method list appropriately reduced.

Limiting benchmarking to chunk_delta_rule only is consistent with the commented out fused_chunk_delta_rule benchmarking code below.


50-53: Tensor shapes standardized correctly.

The tensor shape change from (B, H, seqlen, headdim) to (B, seqlen, H, headdim) maintains consistency with the interface updates across benchmark files.

benchmarks/ops/benchmark_ttt.py (4)

57-60: Tensor shapes standardized for chunk_gla method.

The tensor shape change from (B, H, seqlen, headdim) to (B, seqlen, H, headdim) aligns with the FLA benchmark interface refactoring.


68-71: Tensor shapes standardized for chunk_delta_rule method.

The shape updates maintain consistency with the new interface requirements across all tensor inputs.


79-82: Tensor shapes standardized for chunk_ttt_linear method.

The tensor shape changes are applied consistently to maintain compatibility with the updated interface.


92-97: Tensor shapes standardized for fused_chunk_ttt_linear method.

Both the main tensors and the eta parameter are correctly updated to use the new dimension ordering (B, seqlen, H, headdim) and (B, seqlen, H, 1) respectively.

benchmarks/ops/benchmark_fla.py (1)

46-58: LGTM! Tensor shape standardization is consistent.

The tensor shape changes properly standardize dimension ordering across different providers, aligning with the PR's objective to adapt to the latest FLA benchmark interface.

.github/workflows/intel-a770.yml (1)

18-27: Excellent CI workflow modularization!

The refactoring to use a reusable workflow improves maintainability and reduces code duplication across different hardware configurations. The parameters correctly configure the Intel A770 environment.

.github/workflows/nvidia-4090.yml (1)

18-27: Consistent CI workflow refactoring.

The workflow refactoring maintains consistency with other GPU workflows, properly configuring the NVIDIA 4090 environment through the reusable workflow parameters.

fla/ops/gsa/fused_recurrent.py (2)

155-155: LGTM! Consistent block size calculation.

The block size calculation for the M dimension now uses the same pattern as K and V dimensions, ensuring consistency across all dimensions.


256-257: Clean gradient accumulation refactoring.

The removal of cumulative sum operations in favor of direct gradient tensor summation simplifies the backward pass implementation. This change aligns with the kernel refactoring and improves code clarity.

Also applies to: 266-276, 297-297, 303-304, 313-323, 344-348

tests/ops/test_gsa.py (2)

30-46: LGTM: Parameter consolidation improves test readability.

The consolidation of multiple pytest.mark.parametrize decorators into a single decorator with explicit test tuples is a good improvement for readability and maintainability.


66-66: Good: Enhanced gradient testing with gate_logit_normalizer.

The addition of the gate_logit_normalizer parameter and the corresponding gradient tensors dhkt and dhvt aligns well with the broader refactoring to standardize gradient handling across fused recurrent operations.

Also applies to: 70-72

.github/workflows/triton-builder.yml (2)

19-24: Well-structured workflow with good separation of concerns.

The two-job design (prepare + execute) with dynamic tag discovery and matrix generation is well-architected. The conditional matrix building logic correctly handles the architecture selection.

Also applies to: 36-61


20-20: Address custom runner label warnings.

The static analysis tool flagged unknown runner labels. These appear to be custom self-hosted runner labels.

Ensure that the custom runner labels x64-docker and aarch64-docker are properly configured in your self-hosted runner setup. Consider adding an actionlint.yaml config file to define these custom labels if they're intentional.

# actionlint.yaml
self-hosted-runner:
  labels:
    - "x64-docker"
    - "aarch64-docker"

Also applies to: 42-43

.github/workflows/nvidia-a100.yml (1)

18-27: Excellent refactoring to use reusable workflow.

The simplification from detailed inline jobs to a parameterized reusable workflow significantly improves maintainability and consistency across different GPU platforms. The parameters are appropriate for A100 testing.

.github/workflows/nvidia-h100.yml (1)

18-51: Excellent test coverage across PyTorch versions.

The refactoring to use reusable workflows while maintaining comprehensive test coverage across different PyTorch versions (2.7, nightly, 2.6) is well-executed. The consistent use of skip_gpu_check: true for H100 runners suggests intentional configuration for this hardware.

.github/workflows/reusable-ci-tests.yml (4)

1-37: Well-designed reusable workflow with comprehensive parameters.

The input parameter design is thorough and provides good flexibility for different GPU types, PyTorch versions, and configurations. This centralizes CI logic effectively across multiple workflows.


133-164: Robust dependency installation logic.

The conditional installation logic for different GPU types (NVIDIA vs Intel) and PyTorch versions (stable vs nightly) is well-implemented. The special handling for H100 with causal-conv1d and flash-attn installations is appropriate.


182-191: Good error handling for varlen tests.

The approach of running varlen tests separately with non-critical failure handling (|| echo "Varlen tests failed...") is a good practice for handling known flaky tests while still executing them.

Also applies to: 274-283


149-149: Verify H100 runner condition.

The condition checks for nvidia-h100 but the actual runners used in the workflows are nvidia-h100-1, nvidia-h100-2, etc.

Verify that this condition correctly matches the H100 runners. Consider using a pattern match:

-              if [ "${{ inputs.runner }}" = "nvidia-h100" ]; then
+              if [[ "${{ inputs.runner }}" == nvidia-h100* ]]; then

Or check if this should be based on a different parameter like gpu_type and a specific flag.

.github/workflows/triton-nightly.yml (3)

7-24: LGTM! Well-structured workflow dispatch inputs.

The manual dispatch inputs provide good flexibility for testing specific Python versions and architectures.


31-50: LGTM! Clean job configuration with appropriate timeouts.

The conditional logic correctly handles scheduled vs manual dispatch runs, and the architecture-specific timeouts (120 min for x86_64, 720 min for aarch64) reflect the different build times.


88-103: LGTM! Proper use of dynamic matrix with the reusable workflow.

The job correctly uses the dynamically generated matrix and passes all required parameters to the reusable workflow.

tests/ops/test_gla.py (4)

45-49: Consider the impact of changing from normal to uniform distribution.

The change from torch.randn to torch.rand switches the input distribution from normal (mean=0, std=1) to uniform [0,1). This could affect test coverage since:

  • Normal distribution includes negative values while uniform [0,1) doesn't
  • The value ranges and statistical properties are different

Was this change intentional? If you need positive-only values, consider using torch.randn(...).abs() to maintain similar statistical properties while ensuring positivity.


92-174: LGTM! Well-implemented variable-length sequence test.

The test correctly handles variable-length sequences by:

  • Computing reference outputs segment by segment
  • Properly indexing initial states for each segment
  • Including comprehensive gradient checks

The static analysis warning about too many local variables can be safely ignored as it's common in test functions that need to track multiple intermediate values.


176-258: LGTM! Comprehensive chunk implementation test.

The test properly validates the chunk implementation against the fused recurrent reference, with appropriate handling of the different parameter names (g for chunk_gla, gk for fused_recurrent_gla).


260-336: LGTM! Consistent variable-length test for chunk implementation.

The test follows the same robust pattern as test_fused_recurrent_varlen and provides good coverage for the chunk implementation with variable-length sequences.

fla/ops/common/fused_recurrent.py (3)

270-308: LGTM! Correct gradient computation implementation.

The gradient accumulation logic is mathematically sound:

  • Proper initialization from final hidden state gradients
  • Correct application of chain rule for gate gradients
  • The i_k == 0 condition (line 303) correctly ensures dgv accumulation happens only once across K-dimension blocks

352-352: Good optimization for block size calculation.

Using min(triton.next_power_of_2(K), 64) ensures optimal block sizes for Triton kernels while preventing excessively large blocks that could hurt performance.

Also applies to: 409-409


474-536: LGTM! Proper autograd function implementation.

The FusedRecurrentFunction correctly:

  • Saves all necessary tensors for backward pass
  • Handles the new gradient tensors (dg, dgk, dgv)
  • Returns gradients in the correct order matching forward inputs

@yzhangcs (Member) commented:

@yuweih205 Nice job! Could you fix the lint issues before merging?

@yuweih205 force-pushed the fix-benchmark-adapt branch from 08be6e4 to 85717eb on June 26, 2025 at 09:06
@yuweih205 force-pushed the fix-benchmark-adapt branch from 85717eb to dfd22a2 on June 26, 2025 at 10:10
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
benchmarks/ops/benchmark_fla.py (1)

73-73: Variable scope issue persists despite past review.

The variable g is only defined when "gla" in provider but is used when provider == 'gla_fused_chunk'. While the logic should work (since 'gla_fused_chunk' contains "gla"), static analysis tools flag this as an error, and it creates maintenance risk.

The past review comment on this issue remains valid. Consider the suggested fix to move the g usage inside the conditional block or ensure clearer variable scoping to avoid static analysis warnings and potential future bugs.
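One way to resolve it, sketched against the structure described above; provider, the shapes, and q/k/v come from benchmark_fla.py, and the fused_chunk_gla name is assumed from the provider label:

# initialize g unconditionally so every later read has a binding
g = None
if "gla" in provider:
    g = F.logsigmoid(torch.randn(B, T, H, D, device='cuda')).requires_grad_()

if provider == 'gla_fused_chunk':
    assert g is not None  # holds today: 'gla_fused_chunk' contains 'gla'
    o = fused_chunk_gla(q, k, v, g)[0]  # function name assumed from the label
    o.backward(torch.ones_like(o))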

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 85717eb and dfd22a2.

📒 Files selected for processing (10)
  • benchmarks/ops/benchmark_abc.py (1 hunks)
  • benchmarks/ops/benchmark_based.py (1 hunks)
  • benchmarks/ops/benchmark_delta_rule.py (4 hunks)
  • benchmarks/ops/benchmark_fla.py (2 hunks)
  • benchmarks/ops/benchmark_gla.py (2 hunks)
  • benchmarks/ops/benchmark_gsa.py (2 hunks)
  • benchmarks/ops/benchmark_nsa.py (1 hunks)
  • benchmarks/ops/benchmark_retention.py (1 hunks)
  • benchmarks/ops/benchmark_titans.py (1 hunks)
  • benchmarks/ops/benchmark_ttt.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (9)
  • benchmarks/ops/benchmark_nsa.py
  • benchmarks/ops/benchmark_retention.py
  • benchmarks/ops/benchmark_titans.py
  • benchmarks/ops/benchmark_gsa.py
  • benchmarks/ops/benchmark_abc.py
  • benchmarks/ops/benchmark_delta_rule.py
  • benchmarks/ops/benchmark_ttt.py
  • benchmarks/ops/benchmark_gla.py
  • benchmarks/ops/benchmark_based.py
🧰 Additional context used
🪛 Pylint (3.3.7)
benchmarks/ops/benchmark_fla.py

[error] 73-73: Using variable 'g' before assignment

(E0601)

🔇 Additional comments (4)
benchmarks/ops/benchmark_fla.py (4)

46-49: LGTM: Tensor shape standardization for "based" provider.

The tensor shapes have been correctly standardized to (B, T, H, 16) format, maintaining consistency with the overall refactoring effort.


56-58: LGTM: Tensor shape standardization for default case.

The tensor shapes have been correctly standardized to (B, T, H, D) format, aligning with the consistent dimension ordering adopted across the benchmark.


67-67: LGTM: Correct backward pass modification.

The addition of [0] to access the first element of the returned tuple before calling .backward() is correct and aligns with the updated function interface.


69-69: LGTM: Correct backward pass modification.

The addition of [0] to access the first element of the returned tuple before calling .backward() is correct and aligns with the updated function interface.

@yzhangcs merged commit 3169ff3 into fla-org:main on Jun 26, 2025 (7 of 8 checks passed).