Refactor benchmark: adapt to latest FLA benchmark interface #488
Conversation
""" WalkthroughThis change standardizes tensor dimension orderings across multiple benchmarking scripts for various attention and linear operations. Tensor initializations for queries, keys, values, and auxiliary variables are updated to consistently use the (B, T, H, D) shape, replacing previous variations. Some benchmarked methods are removed or commented out, and keyword argument usage is clarified for specific function calls. Changes
Sequence Diagram(s)

sequenceDiagram
participant BenchmarkScript
participant TensorInit
participant ProviderFunction
BenchmarkScript->>TensorInit: Initialize q, k, v (B, T, H, D)
alt Provider requires auxiliary tensor
TensorInit->>TensorInit: Initialize g, s, f, beta, etc. (B, T, H, D/M)
end
BenchmarkScript->>ProviderFunction: Call benchmarked function with tensors
ProviderFunction-->>BenchmarkScript: Return result(s)
BenchmarkScript->>ProviderFunction: Call backward if needed
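As a concrete reference for the convention above, a minimal initialization sketch; the sizes, dtype, and requires_grad settings are illustrative assumptions, not the benchmarks' exact values:

import torch

B, T, H, D = 8, 2048, 4, 128  # batch, sequence length, heads, head dim (illustrative)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# every benchmarked provider now receives tensors in (B, T, H, D) order
q = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)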
Actionable comments posted: 4
🔭 Outside diff range comments (2)
benchmarks/ops/benchmark_delta_rule.py (1)
8-8: Remove unused import.

The fused_chunk_delta_rule import is no longer used since the method was removed from benchmarking.

-from fla.ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule
+from fla.ops.delta_rule import chunk_delta_rule

tests/ops/test_gsa.py (1)

97-97: Fix gradient assignment bug.

There's an incorrect variable assignment where s.grad is used instead of g.grad.

- tri_dg, s.grad = g.grad.clone(), None
+ tri_dg, g.grad = g.grad.clone(), None
🧹 Nitpick comments (3)
.github/workflows/triton-builder.yml (1)
29-34: Consider adding error handling for API rate limits.

The GitHub API call for fetching the latest tag could potentially hit rate limits or fail. Consider adding retry logic or fallback mechanisms:

- latest_tag=$(curl -sL -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GH_TOKEN" https://api.github.com/repos/triton-lang/triton/tags | jq -r '.[0].name')
+ latest_tag=$(curl -sL --retry 3 --retry-delay 2 -H "Accept: application/vnd.github+json" -H "Authorization: Bearer $GH_TOKEN" https://api.github.com/repos/triton-lang/triton/tags | jq -r '.[0].name')

.github/workflows/reusable-ci-tests.yml (1)
52-78: Consider extracting Conda discovery to a shared action.

The Conda discovery logic is duplicated between the two jobs. This could be extracted to a composite action or simplified. Consider creating a composite action for Conda discovery:

# .github/actions/setup-conda/action.yml
name: 'Setup Conda Environment'
inputs:
  conda_env_name:
    required: true
runs:
  using: 'composite'
  steps:
    - name: Discover Conda Path and Set Env Vars
      # ... your existing logic

Then use it in both jobs:

- name: Setup Conda
  uses: ./.github/actions/setup-conda
  with:
    conda_env_name: ${{ inputs.conda_env_name }}

Also applies to: 210-231
fla/ops/common/fused_recurrent.py (1)
59-59: Rename variable to avoid shadowing Python built-in.

The variable name all shadows the Python built-in function. While this works in Triton kernels, it's better to use a more descriptive name.

- all = B * T
+ total_elements = B * T

And update all references to this variable accordingly (lines 71, 180, 251, 252, 256, 259, 263).
Also applies to: 180-180
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)
- .github/workflows/intel-a770.yml (1 hunks)
- .github/workflows/nvidia-4090.yml (1 hunks)
- .github/workflows/nvidia-a100.yml (1 hunks)
- .github/workflows/nvidia-h100.yml (1 hunks)
- .github/workflows/reusable-build-triton.yml (1 hunks)
- .github/workflows/reusable-ci-tests.yml (1 hunks)
- .github/workflows/triton-builder.yml (1 hunks)
- .github/workflows/triton-nightly.yml (1 hunks)
- benchmarks/ops/benchmark_abc.py (1 hunks)
- benchmarks/ops/benchmark_based.py (1 hunks)
- benchmarks/ops/benchmark_delta_rule.py (3 hunks)
- benchmarks/ops/benchmark_fla.py (2 hunks)
- benchmarks/ops/benchmark_gla.py (2 hunks)
- benchmarks/ops/benchmark_gsa.py (2 hunks)
- benchmarks/ops/benchmark_nsa.py (1 hunks)
- benchmarks/ops/benchmark_retention.py (1 hunks)
- benchmarks/ops/benchmark_titans.py (1 hunks)
- benchmarks/ops/benchmark_ttt.py (2 hunks)
- fla/ops/common/fused_recurrent.py (15 hunks)
- fla/ops/gsa/fused_recurrent.py (5 hunks)
- tests/ops/test_gated_delta.py (1 hunks)
- tests/ops/test_gla.py (8 hunks)
- tests/ops/test_gsa.py (4 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (5)
benchmarks/ops/benchmark_nsa.py (1)
fla/ops/nsa/parallel.py (1)
parallel_nsa(766-881)
benchmarks/ops/benchmark_abc.py (1)
fla/modules/activations.py (1)
logsigmoid(157-158)
benchmarks/ops/benchmark_gla.py (1)
fla/modules/activations.py (1)
logsigmoid(157-158)
benchmarks/ops/benchmark_gsa.py (1)
fla/modules/activations.py (1)
logsigmoid(157-158)
tests/ops/test_gla.py (4)
fla/ops/gla/chunk.py (2)
chunk_gla(1217-1300)
backward(1195-1213)
fla/ops/gla/fused_recurrent.py (1)
fused_recurrent_gla(11-111)
fla/ops/gla/naive.py (1)
naive_recurrent_gla(12-41)
fla/utils.py (1)
assert_close(78-90)
🪛 Pylint (3.3.7)
tests/ops/test_gated_delta.py
[refactor] 153-153: Too many arguments (9/5)
(R0913)
[refactor] 153-153: Too many positional arguments (9/5)
(R0917)
[refactor] 153-153: Too many local variables (19/15)
(R0914)
benchmarks/ops/benchmark_fla.py
[error] 73-73: Using variable 'g' before assignment
(E0601)
tests/ops/test_gla.py
[refactor] 109-109: Too many local variables (30/15)
(R0914)
[refactor] 275-275: Too many local variables (27/15)
(R0914)
🪛 Flake8 (7.2.0)
benchmarks/ops/benchmark_gla.py
[error] 42-42: whitespace before ':'
(E203)
🪛 GitHub Actions: lint
benchmarks/ops/benchmark_gla.py
[error] 42-42: ESLint: whitespace before ':'. (E203)
benchmarks/ops/benchmark_delta_rule.py
[error] 8-8: ESLint: 'fla.ops.delta_rule.fused_chunk_delta_rule' imported but unused. (F401)
🪛 actionlint (1.7.7)
.github/workflows/triton-builder.yml
20-20: label "x64-docker" is unknown. available labels are "windows-latest", "windows-latest-8-cores", "windows-2025", "windows-2022", "windows-2019", "ubuntu-latest", "ubuntu-latest-4-cores", "ubuntu-latest-8-cores", "ubuntu-latest-16-cores", "ubuntu-24.04", "ubuntu-24.04-arm", "ubuntu-22.04", "ubuntu-22.04-arm", "ubuntu-20.04", "macos-latest", "macos-latest-xl", "macos-latest-xlarge", "macos-latest-large", "macos-15-xlarge", "macos-15-large", "macos-15", "macos-14-xl", "macos-14-xlarge", "macos-14-large", "macos-14", "macos-13-xl", "macos-13-xlarge", "macos-13-large", "macos-13", "self-hosted", "x64", "arm", "arm64", "linux", "macos", "windows". if it is a custom label for self-hosted runner, set list of labels in actionlint.yaml config file
(runner-label)
⏰ Context from checks skipped due to timeout of 90000ms (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (44)
benchmarks/ops/benchmark_retention.py (1)
39-41: LGTM! Tensor dimension standardization improves consistency.

The standardization of tensor shapes to (B, T, H, D) across all providers eliminates conditional complexity and aligns with the unified tensor layout conventions being adopted across the benchmark suite.

benchmarks/ops/benchmark_titans.py (1)
62-63: Verify tensor shape compatibility with chunk_titans_linear_ref.

The tensor shapes for w and b have been changed from (H, headdim) to (seqlen, headdim). Please ensure this change is compatible with the expected parameter shapes in the chunk_titans_linear_ref function.

#!/bin/bash
# Description: Verify the expected parameter shapes for chunk_titans_linear_ref function
# Search for the function definition to check parameter expectations
ast-grep --pattern $'def chunk_titans_linear_ref($$$) {
  $$$
}'

benchmarks/ops/benchmark_based.py (1)
50-57: LGTM! Improved provider-specific tensor shape handling.

The refined logic appropriately separates providers that require the (B, H, T, D) tensor layout from those using the standardized (B, T, H, D) layout, maintaining compatibility while moving toward consistency.

.github/workflows/reusable-build-triton.yml (3)
98-101: LGTM! URL patching with verification.

The script correctly patches the LLVM mirror URL and includes verification to ensure the replacement succeeded. The fallback to different file paths (python/setup.py vs setup.py) is a good defensive practice.

156-189: Robust error handling in nightly build processing.

The nightly build wheel processing includes proper error handling with:
- Subshell execution to isolate errors
- Comprehensive cleanup of temporary directories
- All-or-nothing upload strategy based on processing success
This prevents partial uploads of corrupted packages.
194-220: Secure and robust package upload process.

The upload logic includes proper safeguards:
- Verification that wheels exist before upload
- Individual file upload with error checking
- Secure credential handling via secrets
- Clear error messaging for debugging
benchmarks/ops/benchmark_nsa.py (1)
55-55: LGTM! Improved code clarity with explicit keyword arguments.

Converting from positional to keyword arguments (block_indices=indices) improves code readability and reduces the risk of argument order errors. This aligns with the function signature in fla/ops/nsa/parallel.py.

Also applies to: 60-60
tests/ops/test_gated_delta.py (1)
153-153: Function rename aligns with naming conventions.

The rename from test_recurrent_forward to test_fused_recurrent improves consistency with the broader codebase refactoring toward fused operations.

benchmarks/ops/benchmark_abc.py (2)

45-47: Tensor shape standardization improves consistency.

The change from conditional tensor shapes to a unified (B, T, H, D) format across all providers eliminates complexity and aligns with the updated FLA benchmark interface.

50-53: Auxiliary tensor shapes standardized appropriately.

The auxiliary tensors g and s are correctly updated to match the new dimension ordering (B, T, H, D) and (B, T, H, M) respectively.

benchmarks/ops/benchmark_gsa.py (3)
47-49: Tensor shape standardization matches interface updates.

The tensor shape change from (B, H, T, D) to (B, T, H, D) aligns with the FLA benchmark interface refactoring and maintains consistency across benchmark files.

51-54: Auxiliary tensor shapes updated consistently.

The shapes for tensors f and g are correctly updated to (B, T, H, M) and (B, T, H, D) respectively, maintaining consistency with the main tensor shape changes.

66-67: Verify if gsa_recurrent_bwd should be permanently removed.

The gsa_recurrent_bwd benchmark is commented out, which might be temporary for debugging or indicate it's incompatible with the new interface. Please clarify whether this benchmark should be permanently removed or only temporarily disabled. If permanently removed, consider updating the line_vals and line_names arrays to remove the corresponding entries.
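If the provider stays removed, the fix is mechanical; a hedged sketch of keeping the two lists in sync (the values below are illustrative, not the file's actual configuration):

import triton
import triton.testing

# provider list with gsa_recurrent_bwd dropped from BOTH line_vals and line_names
benchmark_config = triton.testing.Benchmark(
    x_names=['T'],
    x_vals=[128 * 2 ** i for i in range(6)],
    line_arg='provider',
    line_vals=['gsa_recurrent', 'gsa_chunk', 'gsa_chunk_bwd'],
    line_names=['gsa_recurrent', 'gsa_chunk', 'gsa_chunk_bwd'],
    ylabel='Execution Time (ms)',
    plot_name='gsa-benchmark',
    args={},
)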
benchmarks/ops/benchmark_delta_rule.py (2)

38-38: Method list appropriately reduced.

Limiting benchmarking to chunk_delta_rule only is consistent with the commented-out fused_chunk_delta_rule benchmarking code below.

50-53: Tensor shapes standardized correctly.

The tensor shape change from (B, H, seqlen, headdim) to (B, seqlen, H, headdim) maintains consistency with the interface updates across benchmark files.

benchmarks/ops/benchmark_ttt.py (4)

57-60: Tensor shapes standardized for chunk_gla method.

The tensor shape change from (B, H, seqlen, headdim) to (B, seqlen, H, headdim) aligns with the FLA benchmark interface refactoring.

68-71: Tensor shapes standardized for chunk_delta_rule method.

The shape updates maintain consistency with the new interface requirements across all tensor inputs.

79-82: Tensor shapes standardized for chunk_ttt_linear method.

The tensor shape changes are applied consistently to maintain compatibility with the updated interface.

92-97: Tensor shapes standardized for fused_chunk_ttt_linear method.

Both the main tensors and the eta parameter are correctly updated to use the new dimension ordering (B, seqlen, H, headdim) and (B, seqlen, H, 1) respectively.

benchmarks/ops/benchmark_fla.py (1)
46-58: LGTM! Tensor shape standardization is consistent.

The tensor shape changes properly standardize dimension ordering across different providers, aligning with the PR's objective to adapt to the latest FLA benchmark interface.
.github/workflows/intel-a770.yml (1)
18-27: Excellent CI workflow modularization!

The refactoring to use a reusable workflow improves maintainability and reduces code duplication across different hardware configurations. The parameters correctly configure the Intel A770 environment.
.github/workflows/nvidia-4090.yml (1)
18-27: Consistent CI workflow refactoring.

The workflow refactoring maintains consistency with other GPU workflows, properly configuring the NVIDIA 4090 environment through the reusable workflow parameters.
fla/ops/gsa/fused_recurrent.py (2)
155-155: LGTM! Consistent block size calculation.

The block size calculation for the M dimension now uses the same pattern as the K and V dimensions, ensuring consistency across all dimensions.

256-257: Clean gradient accumulation refactoring.

The removal of cumulative sum operations in favor of direct gradient tensor summation simplifies the backward pass implementation. This change aligns with the kernel refactoring and improves code clarity.
Also applies to: 266-276, 297-297, 303-304, 313-323, 344-348
tests/ops/test_gsa.py (2)
30-46: LGTM: Parameter consolidation improves test readability.

The consolidation of multiple pytest.mark.parametrize decorators into a single decorator with explicit test tuples is a good improvement for readability and maintainability.

66-66: Good: Enhanced gradient testing with gate_logit_normalizer.

The addition of the gate_logit_normalizer parameter and the corresponding gradient tensors dhkt and dhvt aligns well with the broader refactoring to standardize gradient handling across fused recurrent operations.

Also applies to: 70-72
.github/workflows/triton-builder.yml (2)
19-24: Well-structured workflow with good separation of concerns.

The two-job design (prepare + execute) with dynamic tag discovery and matrix generation is well-architected. The conditional matrix building logic correctly handles the architecture selection.
Also applies to: 36-61
20-20: Address custom runner label warnings.

The static analysis tool flagged unknown runner labels. These appear to be custom self-hosted runner labels. Ensure that the custom runner labels x64-docker and aarch64-docker are properly configured in your self-hosted runner setup. Consider adding an actionlint.yaml config file to define these custom labels if they're intentional.

# actionlint.yaml
runner-labels:
  - "x64-docker"
  - "aarch64-docker"

Also applies to: 42-43
.github/workflows/nvidia-a100.yml (1)
18-27: Excellent refactoring to use reusable workflow.

The simplification from detailed inline jobs to a parameterized reusable workflow significantly improves maintainability and consistency across different GPU platforms. The parameters are appropriate for A100 testing.
.github/workflows/nvidia-h100.yml (1)
18-51: Excellent test coverage across PyTorch versions.

The refactoring to use reusable workflows while maintaining comprehensive test coverage across different PyTorch versions (2.7, nightly, 2.6) is well-executed. The consistent use of skip_gpu_check: true for H100 runners suggests intentional configuration for this hardware.

.github/workflows/reusable-ci-tests.yml (4)

1-37: Well-designed reusable workflow with comprehensive parameters.

The input parameter design is thorough and provides good flexibility for different GPU types, PyTorch versions, and configurations. This centralizes CI logic effectively across multiple workflows.

133-164: Robust dependency installation logic.

The conditional installation logic for different GPU types (NVIDIA vs Intel) and PyTorch versions (stable vs nightly) is well-implemented. The special handling for H100 with causal-conv1d and flash-attn installations is appropriate.
182-191: Good error handling for varlen tests.

The approach of running varlen tests separately with non-critical failure handling (|| echo "Varlen tests failed...") is a good practice for handling known flaky tests while still executing them.

Also applies to: 274-283

149-149: Verify H100 runner condition.

The condition checks for nvidia-h100, but the actual runners used in the workflows are nvidia-h100-1, nvidia-h100-2, etc. Verify that this condition correctly matches the H100 runners. Consider using a pattern match:

- if [ "${{ inputs.runner }}" = "nvidia-h100" ]; then
+ if [[ "${{ inputs.runner }}" == nvidia-h100* ]]; then

Or check if this should be based on a different parameter like gpu_type and a specific flag.

.github/workflows/triton-nightly.yml (3)

7-24: LGTM! Well-structured workflow dispatch inputs.

The manual dispatch inputs provide good flexibility for testing specific Python versions and architectures.

31-50: LGTM! Clean job configuration with appropriate timeouts.

The conditional logic correctly handles scheduled vs manual dispatch runs, and the architecture-specific timeouts (120 min for x86_64, 720 min for aarch64) reflect the different build times.

88-103: LGTM! Proper use of dynamic matrix with the reusable workflow.

The job correctly uses the dynamically generated matrix and passes all required parameters to the reusable workflow.
tests/ops/test_gla.py (4)
45-49: Consider the impact of changing from normal to uniform distribution.

The change from torch.randn to torch.rand switches the input distribution from normal (mean=0, std=1) to uniform [0, 1). This could affect test coverage since:

- Normal distribution includes negative values while uniform [0, 1) doesn't
- The value ranges and statistical properties are different

Was this change intentional? If you need positive-only values, consider using torch.randn(...).abs() to maintain similar statistical properties while ensuring positivity.
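A small illustration of the difference (the tensor shape is an assumption for demonstration only):

import torch

shape = (4, 128, 2, 64)  # (B, T, H, D), illustrative only
uniform = torch.rand(shape)             # uniform on [0, 1): non-negative, bounded
half_normal = torch.randn(shape).abs()  # positive, but keeps a normal-like spread and heavier tail
print(uniform.mean().item(), half_normal.mean().item())  # ~0.5 vs ~0.8 (sqrt(2/pi))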
92-174: LGTM! Well-implemented variable-length sequence test.

The test correctly handles variable-length sequences by:
- Computing reference outputs segment by segment
- Properly indexing initial states for each segment
- Including comprehensive gradient checks
The static analysis warning about too many local variables can be safely ignored as it's common in test functions that need to track multiple intermediate values.
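A hedged sketch of the segment-by-segment reference pattern described above; ref_op, the tensor names, and the packed layout are illustrative assumptions, not the test's actual code:

import torch

def segmented_reference(q, k, v, cu_seqlens, ref_op):
    # q, k, v: (1, total_T, H, D) packed variable-length batch;
    # cu_seqlens: cumulative sequence boundaries, e.g. tensor([0, 512, 1536, 2048])
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # run the reference op on one sequence at a time
        outputs.append(ref_op(q[:, start:end], k[:, start:end], v[:, start:end]))
    return torch.cat(outputs, dim=1)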
176-258: LGTM! Comprehensive chunk implementation test.

The test properly validates the chunk implementation against the fused recurrent reference, with appropriate handling of the different parameter names (g for chunk_gla, gk for fused_recurrent_gla).

260-336: LGTM! Consistent variable-length test for chunk implementation.

The test follows the same robust pattern as test_fused_recurrent_varlen and provides good coverage for the chunk implementation with variable-length sequences.

fla/ops/common/fused_recurrent.py (3)

270-308: LGTM! Correct gradient computation implementation.

The gradient accumulation logic is mathematically sound:
- Proper initialization from final hidden state gradients
- Correct application of chain rule for gate gradients
- The i_k == 0 condition (line 303) correctly ensures dgv accumulation happens only once across K-dimension blocks
352-352: Good optimization for block size calculation.

Using min(triton.next_power_of_2(K), 64) ensures optimal block sizes for Triton kernels while preventing excessively large blocks that could hurt performance.

Also applies to: 409-409
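In isolation, the block-size rule behaves like this (a minimal sketch; K stands for the head dimension):

import triton

def block_size_for(K: int) -> int:
    # round K up to a power of two, but never exceed 64
    return min(triton.next_power_of_2(K), 64)

assert block_size_for(32) == 32
assert block_size_for(48) == 64
assert block_size_for(256) == 64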
474-536: LGTM! Proper autograd function implementation.

The FusedRecurrentFunction correctly:
- Saves all necessary tensors for backward pass
- Handles the new gradient tensors (dg, dgk, dgv)
- Returns gradients in the correct order matching forward inputs
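A toy sketch of the torch.autograd.Function pattern being described; the real FusedRecurrentFunction wraps Triton kernels and handles many more inputs, so the elementwise op here is purely illustrative:

import torch

class ToyGatedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, g):
        ctx.save_for_backward(q, k, v, g)  # save everything backward needs
        return q * k * v * g.exp()

    @staticmethod
    def backward(ctx, do):
        q, k, v, g = ctx.saved_tensors
        e = g.exp()
        # gradients are returned in the same order as forward's inputs
        return do * k * v * e, do * q * v * e, do * q * k * e, do * q * k * v * e

# gradcheck confirms the hand-written backward matches the numerical estimate
x = [torch.randn(2, 3, dtype=torch.double, requires_grad=True) for _ in range(4)]
assert torch.autograd.gradcheck(ToyGatedFunction.apply, x)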
@yuweih205 Nice job, could you fix the lint issues before merging?
Force-pushed from 08be6e4 to 85717eb (Compare)
Force-pushed from 85717eb to dfd22a2 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
benchmarks/ops/benchmark_fla.py (1)
73-73: Variable scope issue persists despite past review.

The variable g is only defined when "gla" in provider but is used when provider == 'gla_fused_chunk'. While the logic should work (since 'gla_fused_chunk' contains "gla"), static analysis tools flag this as an error, and it creates maintenance risk. The past review comment on this issue remains valid. Consider the suggested fix: move the g usage inside the conditional block or ensure clearer variable scoping to avoid static analysis warnings and potential future bugs.
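One hedged way to make the dependency explicit (a sketch that paraphrases the benchmark's structure; every name other than g and provider is illustrative):

import torch

def make_inputs(provider: str, B: int = 4, T: int = 1024, H: int = 4, D: int = 64):
    q = torch.randn(B, T, H, D, requires_grad=True)
    k = torch.randn(B, T, H, D, requires_grad=True)
    v = torch.randn(B, T, H, D, requires_grad=True)
    # define g unconditionally so no branch can read an unassigned variable
    g = torch.randn(B, T, H, D, requires_grad=True) if 'gla' in provider else None
    if provider == 'gla_fused_chunk' and g is None:
        raise ValueError('gla_fused_chunk requires the gate tensor g')
    return q, k, v, g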
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- benchmarks/ops/benchmark_abc.py (1 hunks)
- benchmarks/ops/benchmark_based.py (1 hunks)
- benchmarks/ops/benchmark_delta_rule.py (4 hunks)
- benchmarks/ops/benchmark_fla.py (2 hunks)
- benchmarks/ops/benchmark_gla.py (2 hunks)
- benchmarks/ops/benchmark_gsa.py (2 hunks)
- benchmarks/ops/benchmark_nsa.py (1 hunks)
- benchmarks/ops/benchmark_retention.py (1 hunks)
- benchmarks/ops/benchmark_titans.py (1 hunks)
- benchmarks/ops/benchmark_ttt.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (9)
- benchmarks/ops/benchmark_nsa.py
- benchmarks/ops/benchmark_retention.py
- benchmarks/ops/benchmark_titans.py
- benchmarks/ops/benchmark_gsa.py
- benchmarks/ops/benchmark_abc.py
- benchmarks/ops/benchmark_delta_rule.py
- benchmarks/ops/benchmark_ttt.py
- benchmarks/ops/benchmark_gla.py
- benchmarks/ops/benchmark_based.py
🧰 Additional context used
🪛 Pylint (3.3.7)
benchmarks/ops/benchmark_fla.py
[error] 73-73: Using variable 'g' before assignment
(E0601)
🔇 Additional comments (4)
benchmarks/ops/benchmark_fla.py (4)
46-49: LGTM: Tensor shape standardization for "based" provider.

The tensor shapes have been correctly standardized to the (B, T, H, 16) format, maintaining consistency with the overall refactoring effort.

56-58: LGTM: Tensor shape standardization for default case.

The tensor shapes have been correctly standardized to the (B, T, H, D) format, aligning with the consistent dimension ordering adopted across the benchmark.

67-67: LGTM: Correct backward pass modification.

The addition of [0] to access the first element of the returned tuple before calling .backward() is correct and aligns with the updated function interface.

69-69: LGTM: Correct backward pass modification.

The addition of [0] to access the first element of the returned tuple before calling .backward() is correct and aligns with the updated function interface.
This PR updates 10 benchmark scripts to make them compatible with the latest interface of the FLA project.
✅ Changes:
- triton.testing.do_bench() in: benchmark_abc.py, benchmark_based.py, benchmark_delta_rule.py, benchmark_fla.py, benchmark_gla.py, benchmark_gsa.py, benchmark_retention.py, benchmark_titans.py, benchmark_ttt.py, benchmark_nsa.py
- provider logic
- fla.ops.*

All modified benchmarks are verified to run successfully on the latest version of FLA + Triton.
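For reference, a minimal example of the triton.testing.do_bench interface these scripts call into (assumes a CUDA device; the matmul is only a stand-in for a benchmarked FLA op):

import torch
import triton

x = torch.randn(1024, 1024, device='cuda')
# median, 20th- and 80th-percentile latencies in milliseconds
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x @ x, quantiles=[0.5, 0.2, 0.8])
print(f'{ms:.3f} ms (p20={min_ms:.3f}, p80={max_ms:.3f})')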
Note:

fused_recurrent_gla and fused_recurrent_gsa encounter segmentation faults in their backward pass.

This PR restores 10 benchmark scripts to working order with the latest flash-linear-attention. The changes cover interface adaptation and parameter-compatibility handling, and have been verified on an H100.
Summary by CodeRabbit
Refactor
Chores