Add fused short convolution kernel with L2 norm #661
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a Triton-backed fused short convolution operator with autograd (forward/backward kernels) and integrates optional per-head L2 normalization and activation into ShortConvolution; propagates new fuse_conv_l2/norm/head_dim flags through modules, layers, and configs and adds a benchmark comparing fused vs separate Conv+L2.
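For orientation, here is a small PyTorch reference sketch of what the walkthrough describes the fused operator computing — a causal depthwise short convolution followed by an optional activation and per-head L2 normalization. Shapes, the (D, W) weight layout, and the eps handling are assumptions for illustration; this is not the Triton implementation.

```python
import torch
import torch.nn.functional as F

def fused_short_conv_reference(x, weight, bias=None, head_dim=None, activation='silu', eps=1e-6):
    """x: (B, T, D); weight: (D, W) causal depthwise short-conv filters."""
    B, T, D = x.shape
    W = weight.shape[1]
    # causal depthwise conv: pad W-1 steps on the left of the time axis
    xt = F.pad(x.transpose(1, 2), (W - 1, 0))                                    # (B, D, T+W-1)
    y = F.conv1d(xt, weight.unsqueeze(1), bias=bias, groups=D).transpose(1, 2)   # (B, T, D)
    if activation in ('silu', 'swish'):
        y = F.silu(y)
    if head_dim is not None:
        # optional per-head L2 normalization over each head of size head_dim
        y = y.view(B, T, D // head_dim, head_dim)
        y = y / torch.sqrt((y * y).sum(-1, keepdim=True) + eps)
        y = y.view(B, T, D)
    return y
```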
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant Model as ShortConvolution / caller
participant Fn as FusedShortConvFunction
participant Fwd as fused_short_conv_fwd_kernel (Triton)
participant Bwd as fused_short_conv_bwd_kernel (Triton)
Model->>Fn: apply(x, weight, bias?, residual?, initial_state?, head_dim?, norm?)
Fn->>Fwd: launch forward kernel (HAS_WEIGHT/BIAS/RESIDUAL/NORM/ACTIVATION flags)
Fwd->>Fwd: load inputs, optional norm per-head, activation, compute y (+final state)
Fwd-->>Fn: return y (and final_state if requested)
Fn->>Model: return y
Note over Fn,Bwd: autograd backward
Fn->>Bwd: launch backward kernel (recompute activation/norm if needed)
Bwd->>Bwd: compute dx, dweight, dbias (aggregate/atomics)
Bwd-->>Fn: gradients
Fn-->>Model: propagate gradients
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Summary of Changes

Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates a new, highly optimized fused short convolution operation into the codebase. The primary goal is to provide a performant and memory-efficient solution for short convolutions, particularly when L2 normalization is required. By leveraging Triton for GPU acceleration and implementing a recomputation strategy for gradient calculation, this change aims to improve the overall efficiency of models utilizing this type of convolution.
Code Review
This pull request introduces a new fused short convolution kernel with optional L2 normalization, implemented using Triton. The implementation is comprehensive, supporting variable-length sequences, residual connections, and initial states. The use of a recomputation strategy in the backward pass is a good choice for memory efficiency.
I've identified a few areas for improvement:
- There is some code duplication in the forward kernel that could be refactored for better maintainability.
- The block size for the feature dimension (BD) is hardcoded in some cases, which might not be optimal for all input shapes.
- The backward kernel has a memory access pattern for weight gradients (dw) that could be optimized for better performance.
- There's a minor redundancy in handling 'swish' and 'silu' activations that could be cleaned up.
Overall, this is a solid contribution. Addressing these points will further improve the code's performance and maintainability.
```python
if not USE_INITIAL_STATE:
    for i_w in tl.static_range(-W + 1, 1):
        p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
        # [BT, BD]
        b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
        if HAS_WEIGHT:
            b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
        b_y += b_yi
elif i_t * BT >= W:
    # to make Triton compiler happy, we need to copy codes
    for i_w in tl.static_range(-W + 1, 1):
        p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
        # [BT, BD]
        b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
        if HAS_WEIGHT:
            b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
        b_y += b_yi
```
The code blocks for the if not USE_INITIAL_STATE: and elif i_t * BT >= W: conditions are identical. This duplication can be avoided by combining the conditions, which would improve maintainability. The comment "to make Triton compiler happy" is a bit vague. If this duplication is strictly necessary for performance or to avoid a compiler bug, a more detailed explanation would be helpful. Otherwise, consider refactoring.
```diff
-if not USE_INITIAL_STATE:
-    for i_w in tl.static_range(-W + 1, 1):
-        p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
-        # [BT, BD]
-        b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
-        if HAS_WEIGHT:
-            b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
-        b_y += b_yi
-elif i_t * BT >= W:
-    # to make Triton compiler happy, we need to copy codes
-    for i_w in tl.static_range(-W + 1, 1):
-        p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
-        # [BT, BD]
-        b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
-        if HAS_WEIGHT:
-            b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
-        b_y += b_yi
+if not USE_INITIAL_STATE or i_t * BT >= W:
+    for i_w in tl.static_range(-W + 1, 1):
+        p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
+        # [BT, BD]
+        b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
+        if HAS_WEIGHT:
+            b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
+        b_y += b_yi
```
```python
if ACTIVATION == 'swish' or ACTIVATION == 'silu':
    b_y = b_y * tl.sigmoid(b_y)
```
The activation functions 'swish' and 'silu' are identical. The check ACTIVATION == 'swish' or ACTIVATION == 'silu' is present here and also repeated in the backward kernel (lines 212-213 and 223-225). To improve code clarity and reduce redundancy, you could normalize the activation name in the Python wrapper function (FusedShortConvFunction). For example, you could convert 'swish' to 'silu' before passing it to the Triton kernels, and then only check for ACTIVATION == 'silu' in the kernels.
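As an illustration, a minimal sketch of that normalization on the host side (the wrapper's actual structure is assumed):

```python
# Hypothetical sketch: canonicalize the activation name once on the host side,
# so the Triton kernels only ever need to compare against 'silu'.
def _canonicalize_activation(activation):
    # 'swish' and 'silu' denote the same function x * sigmoid(x)
    return 'silu' if activation == 'swish' else activation

# inside the autograd wrapper (assumed location):
#     activation = _canonicalize_activation(activation)
#     ... launch kernels with ACTIVATION=activation ...
```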
```python
if HAS_WEIGHT:
    b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
    b_dw = tl.sum(b_dy * b_x, 0)
    tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)
```
The store operation to dw is not coalesced. The dw tensor has shape (B*NT, D, W), so the last dimension is the most contiguous. However, the store operation tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, ...) writes to elements with a stride of W in memory for consecutive values of o_d. This can significantly impact performance. To improve memory access patterns, consider changing the layout of dw in the Python wrapper to (B*NT, W, D) and then transposing it back to the required shape after the kernel execution.
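A rough sketch of the suggested layout change on the wrapper side; the tensor names B, NT, D, W follow the review's wording and the allocation code is assumed, not taken from the file:

```python
import torch

# Hypothetical: allocate the per-block weight-gradient buffer as (B*NT, W, D) so that
# consecutive feature indices map to contiguous addresses when the kernel stores b_dw.
dw = x.new_empty(B * NT, W, D, dtype=torch.float32)

# ... launch the backward kernel with the (W, D) inner layout ...

# Reduce over launch blocks and transpose back to the (D, W) layout of the weight.
dweight = dw.sum(0).t().contiguous()  # (D, W)
```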
```python
else:
    BD = 32  # Default fallback or simple value since we don't autotune
```
When use_norm is false, BD is hardcoded to 32. This may be suboptimal for performance if the dimension D is much smaller or larger than 32. Consider making BD more adaptive to D, similar to how it's handled when use_norm is true. For example, you could use BD = min(32, triton.next_power_of_2(D)) or include BD in the autotuning configuration.
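For example, a one-line sketch of an adaptive choice (whether BD should instead join the autotuning configuration is left open):

```python
import triton

# Hypothetical: shrink the feature-dimension block for small D instead of a fixed 32.
BD = min(32, triton.next_power_of_2(D))
```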
Actionable comments posted: 2
🧹 Nitpick comments (1)
fla/ops/convolution/fused_short_conv.py (1)
118-155: Minor: unused kernel arguments and flags can be trimmed for clarity once features are settled.

Static analysis correctly notes several unused parameters: B in both kernels; y, initial_state, dh0, dht, USE_INITIAL_STATE, USE_FINAL_STATE in the backward kernel; the output_final_state argument in the autograd forward. These don't affect runtime correctness but add noise and trigger ARG00x warnings. Once you decide which state features you actually want to support, consider either:
- Removing unused parameters and heuristics from the Triton kernels and Python wrapper, or
- Wiring them through and using the flags for conditional logic.
Until then, you might suppress specific lints locally if you want to keep the signatures stable.
Also applies to: 245-269
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- fla/ops/convolution/fused_short_conv.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/convolution/fused_short_conv.py (3)
- fla/ops/utils/index.py (1): prepare_chunk_indices (114-119)
- fla/utils.py (1): get_multiprocessor_count (360-368)
- fla/ops/utils/matmul.py (1): sigmoid (158-160)
🪛 Flake8 (7.3.0)
fla/ops/convolution/fused_short_conv.py
[error] 205-205: continuation line under-indented for visual indent
(E128)
[error] 213-213: indentation is not a multiple of 4
(E111)
[error] 213-213: over-indented
(E117)
[error] 224-224: indentation is not a multiple of 4
(E111)
[error] 224-224: over-indented
(E117)
[error] 225-225: indentation is not a multiple of 4
(E111)
[error] 242-242: indentation is not a multiple of 4
(E111)
[error] 242-242: over-indented
(E117)
[error] 274-274: local variable 'shape' is assigned to but never used
(F841)
[error] 329-329: local variable 'shape' is assigned to but never used
(F841)
🪛 GitHub Actions: lint
fla/ops/convolution/fused_short_conv.py
[error] 275-275: ruff-check: F841 Local variable 'shape' is assigned to but never used.
[error] 330-330: ruff-check: F841 Local variable 'shape' is assigned to but never used.
[error] 1-1: Trailing whitespace detected by pre-commit hook; file was modified.
[error] 1-1: pre-commit: 4 errors found (2 fixed by hooks, 2 remaining); run with --unsafe-fixes to apply more.
🪛 Ruff (0.14.5)
fla/ops/convolution/fused_short_conv.py
28-28: Unused function argument: B
(ARG001)
128-128: Unused function argument: y
(ARG001)
131-131: Unused function argument: initial_state
(ARG001)
132-132: Unused function argument: dh0
(ARG001)
133-133: Unused function argument: dht
(ARG001)
140-140: Unused function argument: B
(ARG001)
151-151: Unused function argument: USE_INITIAL_STATE
(ARG001)
152-152: Unused function argument: USE_FINAL_STATE
(ARG001)
255-255: Unused static method argument: output_final_state
(ARG004)
274-274: Local variable shape is assigned to but never used
Remove assignment to unused variable shape
(F841)
329-329: Local variable shape is assigned to but never used
Remove assignment to unused variable shape
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (4)
fla/ops/convolution/fused_short_conv.py (4)
96-108: L2 norm + swish/silu backward recomputation looks consistent; main math and masking are sound.

The forward L2 norm over the head dimension (Lines 102–107) and the corresponding backward logic (Lines 187–221) correctly:
- Recompute the convolution output per block.
- Apply bias and optional swish/silu nonlinearity with the right derivative σ(x) * (1 + x * (1 - σ(x))).
- Perform per‑row L2 normalization and its gradient using the dot‑product trick, while respecting EPS and the m_d mask.

Given W is short and everything is kept in fp32, the recomputation strategy here looks reasonable and numerically sound. A reference sketch of the recomputed backward math follows below.

Also applies to: 187-221
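For readers who want to check the math, here is a small self-contained PyTorch reference of the gradient this comment describes (the dot-product trick for the row-wise L2 norm plus the silu derivative); variable names are illustrative and not taken from the kernel:

```python
import torch

def l2norm_silu_backward_reference(y_conv: torch.Tensor, dy: torch.Tensor, eps: float = 1e-6):
    """Reference backward for out = l2norm(silu(y_conv)) along the last dimension."""
    sig = torch.sigmoid(y_conv)
    y_act = y_conv * sig                                             # silu forward
    inv_std = 1.0 / torch.sqrt((y_act * y_act).sum(-1, keepdim=True) + eps)
    y_out = y_act * inv_std                                          # normalized output
    # dot-product trick: d(l2norm) = (dy - y_out * <dy, y_out>) * inv_std
    dy_act = (dy - y_out * (dy * y_out).sum(-1, keepdim=True)) * inv_std
    # silu derivative: sigma(x) * (1 + x * (1 - sigma(x)))
    return dy_act * sig * (1 + y_conv * (1 - sig))

# Sanity check against autograd (eps=0.0 so it matches F.normalize's plain norm):
x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
g = torch.randn(4, 8, dtype=torch.double)
out = torch.nn.functional.normalize(torch.nn.functional.silu(x), dim=-1, eps=0.0)
out.backward(g)
assert torch.allclose(x.grad, l2norm_silu_backward_reference(x, g, eps=0.0), atol=1e-8)
```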
399-407: Public wrapper API and overall fused conv+norm integration look good.

The fused_short_conv wrapper cleanly forwards all options into the autograd Function, keeps the Python API compact, and documents the recomputation strategy. Aside from the noted shape and state‑handling caveats, the core fused short‑conv + optional L2 norm design is solid and in line with the rest of the ops.
270-283: Address unused shape variables and implement symmetric reshape logic for >3D inputs.

The review identifies two confirmed issues:
- Unused shape variables (lines 274, 329): Both forward and backward assign shape = x.shape but never use it, causing F841 lint failures.
- Asymmetric flattening (lines 275–276, 330–331): The conditional rearrange(x, 'b t ... -> b t (...)') flattens x to 3D, but neither method reshapes outputs back to the original shape. If the flattening branch executes with >3D inputs, autograd will encounter a shape mismatch between saved gradients and the original tensor rank.

Choose one approach:
- Restrict to 3D inputs (simpler): Add assert x.ndim == 3 and x.shape[-1] == weight.shape[0] and remove both shape assignments.
- Support >3D inputs (preferred if needed): Save the original shape in ctx, flatten x/residual/initial_state, then reshape outputs (y in forward, dx in backward) before returning, as in the sketch below.
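A minimal sketch of the second option; the class and helper names below are hypothetical and the kernel launches are stubbed out so only the shape bookkeeping is shown:

```python
import torch
from einops import rearrange

def _fused_fwd(x, weight):          # placeholder for the real Triton forward launch
    return x.clone()

def _fused_bwd(x, weight, dy):      # placeholder for the real Triton backward launch
    return dy.clone(), torch.zeros_like(weight)

class FusedShortConvShapeSketch(torch.autograd.Function):
    """Illustrates symmetric flatten/unflatten for >3D inputs; the conv math is stubbed out."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.orig_shape = x.shape                      # e.g. (B, T, H, K)
        if x.dim() > 3:
            x = rearrange(x, 'b t ... -> b t (...)')  # flatten trailing dims into D
        ctx.save_for_backward(x, weight)
        y = _fused_fwd(x, weight)
        return y.view(ctx.orig_shape)                 # restore the caller's shape

    @staticmethod
    def backward(ctx, dy):
        x, weight = ctx.saved_tensors
        if dy.dim() > 3:
            dy = rearrange(dy, 'b t ... -> b t (...)').contiguous()
        dx, dw = _fused_bwd(x, weight, dy)
        return dx.view(ctx.orig_shape), dw
```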
44-53: Verify the varlen grid dimension fix for potential race conditions and duplicate work when B > 1 with cu_seqlens.

The concern is theoretically sound: variable-length batching typically uses B=1 with packed sequences (not B>1 with cu_seqlens). When cu_seqlens is not None (varlen mode), the kernel ignores i_b but the grid still launches with B instances along the z-dimension. This may cause multiple kernel instances to compute and write to identical output locations, creating data races on shared output buffers (y, dw, db) and potentially scaling gradients by a factor of B. The proposed fix—conditioning the z-dimension size on whether varlen is enabled (sketched below)—is sound and should prevent redundant work. Verify this pattern exists in the actual kernel code and test both regular batch and varlen modes to ensure the fix is correct.
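One possible shape of the fix on the wrapper side, as a hedged sketch (the helper name and the NT/BD variables are illustrative, not taken from the file):

```python
import triton

def build_launch_grid(B, D, NT, BD, cu_seqlens=None):
    """Hypothetical helper: collapse the batch axis of the launch grid in varlen mode,
    where sequences are packed and addressed through chunk indices rather than i_b."""
    NB = 1 if cu_seqlens is not None else B
    return (NT, triton.cdiv(D, BD), NB)
```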
```python
for i_w in range(0, W):
    p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
    b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)

    if USE_NORM:
        # Recompute y_conv at T_global = i_t*BT + i_w + t_local
        # We need to loop over k (kernel support) to compute convolution
        b_y_conv = tl.zeros((BT, BD), dtype=tl.float32)
        t_local = tl.arange(0, BT)

        for k in range(0, W):
            w_k = tl.sum(b_w * (o_w[None, :] == k), 1)
            # Forward: y[t] = sum_{j=0}^{W-1} x[t - W + 1 + j] * w[j]
            # Here t = i_t * BT + i_w + t_local, j = k
            # So x index = t - W + 1 + k = (i_t * BT + i_w + t_local) - W + 1 + k
            x_offset = i_t * BT + i_w - W + 1 + k
            m_x_valid = (x_offset + t_local >= 0) & (x_offset + t_local < T)

            # We need to reload x from memory as it's not in registers.
            # Constructing pointers manually to allow random access in loop
            # This is efficient enough for small W.
            val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
                            mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
            b_y_conv += val_x * w_k[None, :]

        if HAS_BIAS:
            b_y_conv += tl.load(bias + o_d, mask=m_d).to(tl.float32)

        b_y_act = b_y_conv
        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
            b_y_act = b_y_conv * tl.sigmoid(b_y_conv)

        b_var = tl.sum(b_y_act * b_y_act, 1)
        b_std = tl.sqrt(b_var + EPS)
        b_inv_std = 1.0 / b_std
        b_y_out = b_y_act * b_inv_std[:, None]
        b_dot = tl.sum(b_dy * b_y_out, 1)
        b_dy = (b_dy - b_y_out * b_dot[:, None]) * b_inv_std[:, None]

        # For activation backward
        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
            b_sig = tl.sigmoid(b_y_conv)
            b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))

    b_wdy = b_dy
    if HAS_WEIGHT:
        b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
        b_dw = tl.sum(b_dy * b_x, 0)
        tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)

    if HAS_BIAS and i_w == 0:
        b_db += tl.sum(b_dy, 0)

    b_dx += b_wdy
```
🧩 Analysis chain
Backward kernel currently ignores initial_state/dh0/dht; output_final_state is accepted but not implemented.
In fused_short_conv_bwd_kernel, the parameters initial_state, dh0, and dht (Lines 131–133) and the flags USE_INITIAL_STATE / USE_FINAL_STATE are unused; gradients for the initial state and final state are effectively always zero. In the Python autograd wrapper:
- forward always returns (y, None) (Line 317), regardless of output_final_state.
- backward accepts dht but does nothing with it and always returns dh0=None (Lines 347, 389).
If initial/final state gradients or a real final state output are not supported yet, it is safer to:
- Explicitly assert that output_final_state is False and initial_state.requires_grad is False (or raise NotImplementedError when violated).
- Drop unused heuristics/args from the Triton kernel until they're actually used, to avoid confusion.
Alternatively, implement the missing state logic and wire USE_INITIAL_STATE/USE_FINAL_STATE through the kernels.
Also applies to: 354-387
I'll generate shell scripts to verify the claims in the review comment about unused parameters and unimplemented features.
🏁 Script executed:

```shell
#!/bin/bash
# Verify backward kernel parameter declarations and usage
echo "=== Checking backward kernel signature and parameters ==="
rg -n "def fused_short_conv_bwd_kernel|initial_state|dh0|dht|USE_INITIAL_STATE|USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -A 2 -B 2
```

Length of output: 3729

🏁 Script executed:

```shell
#!/bin/bash
# Check forward function signature and return statements
echo "=== Checking forward function behavior ==="
rg -n "def forward|return.*output_final_state|output_final_state.*return" fla/ops/convolution/fused_short_conv.py -A 3 -B 1
```

Length of output: 164

🏁 Script executed:

```shell
#!/bin/bash
# Check backward function signature and return statements
echo "=== Checking backward function behavior ==="
rg -n "def backward|return dh0|return.*dh0" fla/ops/convolution/fused_short_conv.py -A 2 -B 1
```

Length of output: 383

🏁 Script executed:

```shell
#!/bin/bash
# Get line ranges mentioned in the review to examine full context
echo "=== Lines 131-133 (kernel parameters) ==="
sed -n '131,133p' fla/ops/convolution/fused_short_conv.py
echo -e "\n=== Lines 317 (forward return) ==="
sed -n '310,325p' fla/ops/convolution/fused_short_conv.py
echo -e "\n=== Lines 347, 389 (backward) ==="
sed -n '340,395p' fla/ops/convolution/fused_short_conv.py
echo -e "\n=== Lines 354-387 (also applies to) ==="
sed -n '354,387p' fla/ops/convolution/fused_short_conv.py
```

Length of output: 3410

🏁 Script executed:

```shell
#!/bin/bash
# Search for any conditional logic using these flags in the kernel
echo "=== Search for flag usage patterns ==="
rg -n "USE_INITIAL_STATE|USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -B 2 -A 2
```

Length of output: 1274
Let me search more thoroughly to ensure these flags aren't used anywhere in the backward kernel body:
🏁 Script executed:

```shell
#!/bin/bash
# Get the complete backward kernel body to inspect for flag usage
echo "=== Full backward kernel body (lines 126-237) ==="
sed -n '126,237p' fla/ops/convolution/fused_short_conv.py
```

Length of output: 4176

🏁 Script executed:

```shell
#!/bin/bash
# Search for any conditional branching on these flags in backward kernel
echo "=== Search for if statements using USE_INITIAL_STATE or USE_FINAL_STATE in backward context ==="
rg -n "if.*USE_INITIAL_STATE|if.*USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -B 1 -A 2
```

Length of output: 343

🏁 Script executed:

```shell
#!/bin/bash
# Verify: does backward kernel ever reference dh0 or dht besides parameters?
echo "=== References to dh0 in backward context (after line 126) ==="
awk 'NR>=126 && NR<=237 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py | grep -i "dh0\|dht"
```

Length of output: 90

🏁 Script executed:

```shell
#!/bin/bash
# Check what parameters backward kernel accepts and how many it should return
echo "=== Backward function signature and return statement structure ==="
awk 'NR>=320 && NR<=330 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py
echo ""
echo "=== Forward function signature to understand parameter count ==="
awk 'NR>=248 && NR<=270 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py
```

Length of output: 1465
Implement missing backward pass state gradients or add explicit NotImplementedError guards.
The verification confirms all claims: the backward kernel declares USE_INITIAL_STATE, USE_FINAL_STATE, dh0, and dht (lines 131–133, 151–152) but never uses them in the kernel body (lines 126–237). The Python wrapper accepts dht in backward() but always returns dh0=None (line 389). Additionally, output_final_state is accepted but hardcoded to return None (line 317).
Choose one approach:
- Implement: Wire USE_INITIAL_STATE/USE_FINAL_STATE conditionals through the backward kernel and compute/return proper gradients for initial/final states.
- Guard: Add assertions in the Python wrapper to reject unsupported cases (see the sketch after this list):
  - assert not output_final_state, "output_final_state not implemented"
  - assert not (initial_state is not None and initial_state.requires_grad), "initial_state gradients not implemented"

Leaving unused parameters and flags creates API confusion and makes future maintenance harder.
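A minimal sketch of the guard option; the helper is hypothetical and would be called at the top of the autograd forward described in the review:

```python
def _check_unsupported_state_options(initial_state, output_final_state):
    """Hypothetical guard helper for the fused path."""
    if output_final_state:
        raise NotImplementedError("output_final_state is not implemented in the fused path")
    if initial_state is not None and initial_state.requires_grad:
        raise NotImplementedError("gradients w.r.t. initial_state are not implemented")
```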
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 205-205: continuation line under-indented for visual indent
(E128)
[error] 213-213: indentation is not a multiple of 4
(E111)
[error] 213-213: over-indented
(E117)
[error] 224-224: indentation is not a multiple of 4
(E111)
[error] 224-224: over-indented
(E117)
[error] 225-225: indentation is not a multiple of 4
(E111)
🤖 Prompt for AI Agents
In fla/ops/convolution/fused_short_conv.py around lines 183-237 (flags declared
around 131-133 and 151-152; Python wrapper behavior at ~317 and ~389): the
kernel and wrapper declare/accept USE_INITIAL_STATE/USE_FINAL_STATE, dh0 and dht
and output_final_state but never compute or return those gradients; either
implement state-gradient flow or explicitly guard/disable the feature. Option A
(guard): add assertions in the Python wrapper to reject output_final_state=True
and initial_state.requires_grad (e.g., assert not output_final_state,
"output_final_state not implemented"; assert not (initial_state is not None and
initial_state.requires_grad), "initial_state gradients not implemented"), and
clear/remove the unused kernel flags/parameters to avoid dead API. Option B
(implement): wire state gradients through the kernel by accumulating gradients
for initial and final states similarly to b_dx (compute per-step contributions
to dht and dh0 based on recurrence and normalization/activation derivatives),
store them into the kernel output buffers, and update the Python wrapper to
return concrete dh0/dht tensors instead of None; ensure shapes/dtypes match and
guarded by USE_INITIAL_STATE/USE_FINAL_STATE. Choose one approach and apply the
corresponding changes consistently in both kernel and wrapper.
```python
val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
                mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
b_y_conv += val_x * w_k[None, :]
```
🛠️ Refactor suggestion | 🟠 Major
Fix indentation and continuation formatting to satisfy flake8 and pre‑commit.
The lint errors correspond to:
- The continuation line for val_x (Line 205) being under‑indented for its visual indent.
- Over‑indented lines for b_y_act, b_sig, b_dy (Lines 213, 224, 225).
- Over‑indented tl.store for db (Line 242).
- Plus trailing whitespace noted by pre‑commit around the file header.
A minimal compliant patch would look like:
```diff
-            val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
-                        mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
+            val_x = tl.load(
+                x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
+                mask=m_x_valid[:, None] & m_d[None, :],
+                other=0.0,
+            ).to(tl.float32)
@@
-          b_y_act = b_y_conv
-          if ACTIVATION == 'swish' or ACTIVATION == 'silu':
-            b_y_act = b_y_conv * tl.sigmoid(b_y_conv)
+        b_y_act = b_y_conv
+        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
+            b_y_act = b_y_conv * tl.sigmoid(b_y_conv)
@@
-        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
-          b_sig = tl.sigmoid(b_y_conv)
-          b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))
+        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
+            b_sig = tl.sigmoid(b_y_conv)
+            b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))
@@
-    if HAS_BIAS:
-      tl.store(db + i_tg * D + o_d, b_db.to(db.dtype.element_ty), mask=m_d)
+    if HAS_BIAS:
+        tl.store(db + i_tg * D + o_d, b_db.to(db.dtype.element_ty), mask=m_d)
```

…and remove any trailing whitespace on the file header line. This should clear the flake8 and pre‑commit errors listed in the pipeline.
Also applies to: 211-226, 241-242
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 205-205: continuation line under-indented for visual indent
(E128)
🤖 Prompt for AI Agents
In fla/ops/convolution/fused_short_conv.py around lines 204-206 (and also apply
to 211-226, 241-242), fix indentation and line-continuation formatting to
satisfy flake8: align the continued tl.load call so the continuation expression
is indented to the visual indent level (e.g., align with the opening parenthesis
or add one extra indent level), adjust the indentation of subsequent assignments
b_y_act, b_sig, b_dy to match their block scope (remove the extra spaces), align
the tl.store for db with its surrounding block (remove over-indentation), and
remove any trailing whitespace on the file header line; run the linter to verify
no further continuation/indentation issues remain.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/modules/benchmark_fused_conv_l2.py (1 hunks)
- fla/layers/delta_net.py (5 hunks)
- fla/modules/convolution.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
fla/layers/delta_net.py (1)
- fla/modules/convolution.py (1): ShortConvolution (806-1071)

fla/modules/convolution.py (1)
- fla/ops/convolution/fused_short_conv.py (1): fused_short_conv (391-413)

benchmarks/modules/benchmark_fused_conv_l2.py (2)
- fla/modules/convolution.py (2): ShortConvolution (806-1071), backward (675-689)
- fla/modules/l2norm.py (1): l2norm (263-268)
🪛 GitHub Actions: lint
benchmarks/modules/benchmark_fused_conv_l2.py
[error] 1-1: Pre-commit trailing-whitespace hook detected issues and fixed them.
🪛 Ruff (0.14.5)
fla/layers/delta_net.py
91-91: Unused method argument: kwargs
(ARG002)
fla/modules/convolution.py
982-982: Avoid specifying long messages outside the exception class
(TRY003)
1037-1037: Avoid specifying long messages outside the exception class
(TRY003)
1062-1062: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (1)
fla/modules/convolution.py (1)
834-842: Fused L2 normalization integration in ShortConvolution looks consistent

The way ShortConvolution now handles normalization and head_dim is coherent:
- norm/norm_eps are stored and validated so that only norm='l2' with the Triton backend is allowed at construction.
- In forward, enabling norm switches to the fused_short_conv path and enforces head_dim at call time, which matches how DeltaNet and the benchmark supply per‑head dimensions.
- In step, both Triton and CUDA update paths apply the same post‑conv L2 normalization when norm is set, again guarded by head_dim, so streaming and full‑sequence behavior stay aligned.

Given the current usages (only Q/K convs set norm and they all pass head_dim), this design is sound and doesn't disturb existing norm=None call sites. A usage sketch follows below.

Also applies to: 860-870, 909-997, 1012-1067
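For orientation, a hedged usage sketch of the fused path as this comment describes it — the exact constructor arguments and return convention are assumptions, not copied from fla/modules/convolution.py:

```python
import torch
from fla.modules import ShortConvolution

# Hypothetical wiring: a Q/K convolution with L2 normalization fused into the conv kernel.
hidden_size, head_dim, kernel_size = 2048, 128, 4
conv_q = ShortConvolution(
    hidden_size=hidden_size,
    kernel_size=kernel_size,
    activation='silu',
    norm='l2',          # new flag: fuse per-head L2 norm into the conv
    norm_eps=1e-5,
).cuda()

x = torch.randn(2, 1024, hidden_size, device='cuda')
# head_dim must be supplied at call time when norm is enabled,
# so the kernel knows the per-head normalization width.
y, _ = conv_q(x, head_dim=head_dim)  # assumed (output, cache) return convention
```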
```python
# Benchmark Combined
print("\n" + "="*80)
print("Forward + Backward Pass")
print("="*80)

def combined_sep():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = separate_conv_l2(x, conv_separate, head_dim)
    y.backward(grad_sep, retain_graph=True)

t_sep_combined = benchmark.Timer(
    stmt="combined_sep()",
    globals={"combined_sep": combined_sep},
)
m_sep_combined = t_sep_combined.timeit(100)
print(f"Separate: {m_sep_combined}")

def combined_fused():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = fused_conv_l2(x, conv_fused, head_dim)
    y.backward(grad_fused, retain_graph=True)
```
🧩 Analysis chain
Avoid retain_graph=True in the combined forward+backward benchmarks
In combined_sep / combined_fused you rebuild the graph on every call:

```python
def combined_sep():
    ...
    y = separate_conv_l2(x, conv_separate, head_dim)
    y.backward(grad_sep, retain_graph=True)

def combined_fused():
    ...
    y = fused_conv_l2(x, conv_fused, head_dim)
    y.backward(grad_fused, retain_graph=True)
```

Because each benchmark iteration creates a new graph, retain_graph=True keeps all of them alive, causing linear memory growth and potentially OOM for these tensor sizes, as well as skewing timings.
You only need retain_graph=True in the backward‑only benchmarks where the same forward graph is reused; here you can safely drop it:

```diff
-    y = separate_conv_l2(x, conv_separate, head_dim)
-    y.backward(grad_sep, retain_graph=True)
+    y = separate_conv_l2(x, conv_separate, head_dim)
+    y.backward(grad_sep)
 ...
-    y = fused_conv_l2(x, conv_fused, head_dim)
-    y.backward(grad_fused, retain_graph=True)
+    y = fused_conv_l2(x, conv_fused, head_dim)
+    y.backward(grad_fused)
```

I'll verify the review comment by examining the actual code structure and backward/combined benchmark implementations.
🏁 Script executed:

```shell
#!/bin/bash
# Examine the benchmark file structure around the claimed lines
cat benchmarks/modules/benchmark_fused_conv_l2.py | head -150 | tail -60
```

Length of output: 1801

🏁 Script executed:

```shell
#!/bin/bash
# Check the backward-only benchmarks mentioned in the scratchpad (lines ~92-97, 108-113)
sed -n '80,120p' benchmarks/modules/benchmark_fused_conv_l2.py
```

Length of output: 1209

🏁 Script executed:

```shell
#!/bin/bash
# Get a full view of the entire benchmark file to understand structure
wc -l benchmarks/modules/benchmark_fused_conv_l2.py
```

Length of output: 49

🏁 Script executed:

```shell
#!/bin/bash
# Search for all retain_graph occurrences in the file
rg -n "retain_graph" benchmarks/modules/benchmark_fused_conv_l2.py -C 3
```

Length of output: 1071

🏁 Script executed:

```shell
#!/bin/bash
# Look for the forward-only benchmark section to confirm structure
sed -n '1,80p' benchmarks/modules/benchmark_fused_conv_l2.py | tail -40
```

Length of output: 1088
Remove retain_graph=True from the combined forward+backward benchmarks
The code inspection confirms the review's analysis. The backward-only benchmarks (lines 89–112) correctly use retain_graph=True because they precompute the forward pass once and then reuse the same graph for 100 backward iterations. However, the combined benchmarks (lines 127–131 and 141–145) rebuild the forward graph on each timer iteration, making retain_graph=True unnecessary and harmful—all 100 graphs remain in memory simultaneously, causing linear memory growth and timing distortion.
Remove retain_graph=True from combined_sep() and combined_fused() at lines 131 and 145:
```diff
 def combined_sep():
     for xi in [x]:
         if isinstance(xi, torch.Tensor):
             xi.grad = None
     y = separate_conv_l2(x, conv_separate, head_dim)
-    y.backward(grad_sep, retain_graph=True)
+    y.backward(grad_sep)

 def combined_fused():
     for xi in [x]:
         if isinstance(xi, torch.Tensor):
             xi.grad = None
     y = fused_conv_l2(x, conv_fused, head_dim)
-    y.backward(grad_fused, retain_graph=True)
+    y.backward(grad_fused)
```

🤖 Prompt for AI Agents
In benchmarks/modules/benchmark_fused_conv_l2.py around lines 121 to 146, the
combined forward+backward benchmark functions combined_sep() and
combined_fused() incorrectly pass retain_graph=True to backward, which causes
every iteration to keep its autograd graph in memory; remove the
retain_graph=True argument from both y.backward(...) calls (lines ~131 and ~145)
so backward() uses its default behavior and the graph is freed each iteration;
keep the existing grad reset logic intact and do not add retain_graph elsewhere.
Actionable comments posted: 0
🧹 Nitpick comments (4)
fla/layers/kda.py (1)
74-75: fuse_conv_l2 gating and L2 placement look logically sound

The interplay between self.fuse_conv_l2, ShortConvolution(norm='l2', head_dim=...), and use_qk_l2norm_in_kernel=not self.fuse_conv_l2 is consistent:
- use_short_conv=False ⇒ self.fuse_conv_l2=False, KDA kernel keeps handling q/k L2 as before.
- use_short_conv=True, fuse_conv_l2=True ⇒ L2 done inside ShortConvolution (fused conv+L2), and KDA kernel L2 is disabled, avoiding double normalization.
- use_short_conv=True, fuse_conv_l2=False ⇒ short conv without norm, KDA kernel L2 stays enabled (restoring prior behavior).

This cleanly ensures exactly one L2 path is active at a time. You might optionally add a short note in the class docstring explaining that fuse_conv_l2 moves q/k L2 from the KDA kernel into the short-conv path when enabled.

Also applies to: 87-88, 121-137, 199-211, 240-262
fla/layers/delta_net.py (1)
72-107: fuse_conv_l2 / kernel L2 interplay looks correct; consider minor cleanupsThe new wiring in
DeltaNetkeeps L2 normalization semantics consistent while enabling the fused conv path:
self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv and (qk_norm == 'l2')ensures the fused conv+L2 path only activates when both short conv and L2 q/k norm are actually in use, and it naturally falls back whenuse_short_conv=Falseorqk_norm!='l2'.- The deprecated
fuse_normalias mapping intofuse_conv_l2with a warning gives a smooth migration path for older call sites.- Passing
norm='l2' if self.fuse_conv_l2 else Noneandnorm_eps=norm_epsinto the q/kShortConvolutionmodules, and thenhead_dim=self.head_k_dim if self.fuse_conv_l2 else Noneinforward, lines up withShortConvolution’s fused‑norm contract (head_dim required only when norm is active).- Using
use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2)in bothfused_recurrent_delta_ruleandchunk_delta_ruleensures q/k are normalized exactly once: in the fused conv when fusion is on, or in the kernel when it’s off—including theuse_short_conv=Falsecase, which preserves the original behavior.Minor follow‑ups you might consider:
__init__(..., **kwargs)currently doesn’t usekwargsand triggers Ruff’s ARG002; if you want to keep the flexible signature, adding a smalldel kwargs # noqa: ARG002or similar would silence the warning without changing behavior.- The class docstring doesn’t yet mention
fuse_conv_l2/fuse_norm; adding a brief arg description would help users understand how to toggle the fused conv+L2 path.Also applies to: 139-159, 208-225, 261-283
fla/layers/mom.py (1)
689-705: shared_o path mirrors main path for fused conv L2 behaviorThe shared output branch correctly reuses q/k ShortConvolution with
head_dimgated byself.fuse_conv_l2and passesuse_qk_l2norm_in_kernel = not self.fuse_conv_l2to both chunk and fused-recurrent kernels, keeping fused vs. unfused semantics aligned between main and shared branches.You may later want to factor the repeated q/k-conv + kernel-call wiring into a small helper to reduce duplication across the main and
shared_opaths.Also applies to: 719-742
fla/layers/gated_deltaproduct.py (1)
196-220: Forward conv path and kernel flags maintain correct L2-placement with householder structurePassing
head_dim=self.head_k_diminto q/k ShortConvolution only whenself.fuse_conv_l2is enabled, and settinguse_qk_l2norm_in_kernel = not self.fuse_conv_l2for bothchunk_gated_delta_productandfused_recurrent_gated_delta_rule, ensures:
- Q/K are normalized once (in conv or kernel, not both),
- the additional
num_householderfactor is handled via reshapes without breaking per-head normalization semantics.Consider a small shared helper for the repeated “q/k conv + head_dim + use_qk_l2norm_in_kernel” wiring across GatedDeltaNet/GatedDeltaProduct to reduce copy-paste when adding future fusion options.
Also applies to: 242-276
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
fla/layers/comba.py(6 hunks)fla/layers/delta_net.py(5 hunks)fla/layers/gated_deltanet.py(6 hunks)fla/layers/gated_deltaproduct.py(6 hunks)fla/layers/kda.py(6 hunks)fla/layers/mesa_net.py(5 hunks)fla/layers/mom.py(10 hunks)fla/models/comba/configuration_comba.py(2 hunks)fla/models/comba/modeling_comba.py(1 hunks)fla/models/delta_net/configuration_delta_net.py(2 hunks)fla/models/delta_net/modeling_delta_net.py(1 hunks)fla/models/gated_deltanet/configuration_gated_deltanet.py(2 hunks)fla/models/gated_deltanet/modeling_gated_deltanet.py(1 hunks)fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py(2 hunks)fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py(1 hunks)fla/models/kda/configuration_kda.py(2 hunks)fla/models/kda/modeling_kda.py(1 hunks)fla/models/mesa_net/configuration_mesa_net.py(2 hunks)fla/models/mesa_net/modeling_mesa_net.py(1 hunks)fla/models/mom/configuration_mom.py(2 hunks)fla/models/mom/modeling_mom.py(1 hunks)fla/ops/convolution/__init__.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
fla/layers/comba.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/layers/mom.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/ops/convolution/__init__.py (1)
fla/ops/convolution/fused_short_conv.py (1)
fused_short_conv(391-413)
fla/layers/kda.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/layers/mesa_net.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/layers/gated_deltaproduct.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/layers/gated_deltanet.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
fla/layers/delta_net.py (1)
fla/modules/convolution.py (1)
ShortConvolution(806-1071)
🪛 GitHub Actions: lint
fla/ops/convolution/__init__.py
[error] 1-1: End-of-file fixer modified the file to ensure proper EOF newline.
🪛 Ruff (0.14.5)
fla/layers/delta_net.py
92-92: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (30)
fla/ops/convolution/__init__.py (1)
1-3: Re-export looks goodSimple, correct exposure of
fused_short_convvia the package namespace; nothing else to change here.fla/models/kda/configuration_kda.py (1)
35-63: Config wiring forfuse_conv_l2is consistentAdding
fuse_conv_l2at the end of the signature and storing it alongside otherfuse_*flags is backward compatible for positional callers and matches how it’s used downstream.fla/models/mom/modeling_mom.py (1)
134-151: MoM attention now correctly receivesfuse_conv_l2Forwarding
config.fuse_conv_l2intoMomAttentioncompletes the config-driven wiring for this backend without altering existing control flow.fla/models/comba/modeling_comba.py (1)
56-69: Comba block fuse_conv_l2 propagation is consistentPassing
config.fuse_conv_l2intoCombaaligns this model with the rest of the stack and doesn’t disturb existing behavior when the flag keeps its default.fla/models/mom/configuration_mom.py (1)
44-44: LGTM!The addition of the
fuse_conv_l2configuration parameter follows the established pattern for other fusion flags in the codebase and is properly stored as an instance attribute.Also applies to: 79-79
fla/models/delta_net/modeling_delta_net.py (1)
69-69: LGTM!The parameter is correctly forwarded from the configuration to the DeltaNet layer constructor, maintaining consistency with other configuration options.
fla/models/kda/modeling_kda.py (1)
68-68: LGTM!The parameter is correctly forwarded from the configuration to the KimiDeltaAttention layer constructor.
fla/models/comba/configuration_comba.py (1)
39-39: LGTM!The addition of the
fuse_conv_l2configuration parameter is consistent with the pattern established in other configuration classes.Also applies to: 71-71
fla/models/delta_net/configuration_delta_net.py (1)
40-40: LGTM!The addition of the
fuse_conv_l2configuration parameter follows the established pattern and is properly initialized.Also applies to: 71-71
fla/models/gated_deltanet/modeling_gated_deltanet.py (1)
69-69: LGTM!The parameter is correctly forwarded from the configuration to the GatedDeltaNet layer constructor.
fla/models/mesa_net/configuration_mesa_net.py (1)
42-42: LGTM!The addition of the
fuse_conv_l2configuration parameter is consistent with the pattern used in other configuration classes.Also applies to: 72-72
fla/layers/mesa_net.py (5)
69-69: LGTM!The
fuse_conv_l2parameter is properly added and stored as an instance attribute, following the established pattern in the codebase.Also applies to: 90-90
106-121: LGTM!The ShortConvolution layers are correctly configured with L2 normalization when
fuse_conv_l2is enabled. The conditionalnorm='l2' if self.fuse_conv_l2 else Nonepattern properly enables/disables the fused normalization path, andnorm_epsis consistently provided for both q and k convolutions.
158-171: LGTM!The
head_dimparameter is correctly passed conditionally based onfuse_conv_l2. When fusion is enabled,head_dimis provided to enable per-head normalization within the convolution operation. The pattern is consistent for both q and k projections.
185-197: LGTM!The
use_qk_l2norm_in_kernelflag is correctly set tonot self.fuse_conv_l2, ensuring that L2 normalization is only applied in the kernel when it's not already fused into the convolution operation. This prevents double normalization.
200-202: LGTM!The decoding path correctly applies L2 normalization to q and k only when
fuse_conv_l2is False, maintaining consistency with the training path and preventing redundant normalization when fusion is enabled.fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (1)
56-71: fuse_conv_l2 correctly threaded into GatedDeltaProductPassing
fuse_conv_l2=config.fuse_conv_l2intoGatedDeltaProductkeeps the new fused short‑conv+L2 behavior configurable at the model level and consistent with the updated config; the change is localized and backward compatible with the default value on the config side.fla/models/mesa_net/modeling_mesa_net.py (1)
56-70: Consistent propagation of fuse_conv_l2 into MesaNetForwarding
fuse_conv_l2=config.fuse_conv_l2intoMesaNetmatches the new layer API and keeps the fusion toggle controllable from the config without altering existing call sites.fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py (1)
11-45: New fuse_conv_l2 config flag is cleanly integratedAdding
fuse_conv_l2: bool = Trueand storing it asself.fuse_conv_l2gives a clear config‑level switch for the fused short‑conv+L2 path while preserving backward compatibility via the defaulted argument. If you keep separate config docs/JSON schema, consider listing this flag there as well for discoverability.Also applies to: 65-70
fla/layers/comba.py (1)
77-96: Comba’s fuse_conv_l2 gating is consistent and preserves prior behavior
Comba’s newfuse_conv_l2path looks well‑designed:
self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_convcleanly disables fusion when short conv is off.- q/k
ShortConvolutionlayers usenorm='l2'andnorm_eps=norm_epsonly when fusion is enabled, andforwardpasseshead_dim=self.head_k_dimin that case, aligning withShortConvolution’s fused‑norm requirements.use_qk_l2norm_in_kernel=not self.fuse_conv_l2in bothchunk_combaandfused_recurrent_combaensures q/k L2 normalization happens either in the conv or in the kernel, but not both, and that turning offfuse_conv_l2(oruse_short_conv) reverts to the original in‑kernel‑only scheme.Also applies to: 104-120, 177-200, 248-260, 291-316
fla/layers/mom.py (4)
279-317: fuse_conv_l2 flag wiring in MomAttention constructor is coherentConditioning
self.fuse_conv_l2onself.use_short_convkeeps fusion disabled when the conv path is off, avoiding inconsistent states; ctor signature and attribute setup look consistent with the rest of the class.
379-396: ShortConvolution q/k initialization correctly gates fused L2 normalizationPassing
norm='l2'andnorm_epsonly whenself.fuse_conv_l2is enabled cleanly switches between fused and unfused conv paths while leavingv_conv1duntouched, which matches the intended “Q/K-only” normalization.
517-545: head_dim propagation into q/k ShortConvolution calls matches fused L2 expectationsUsing
head_dim=self.head_qk_dim if self.fuse_conv_l2 else Noneensures the fused path receives the per-head size it needs, while the unfused path behaves as before. Givenself.key_dim = num_heads * head_dim, the hidden size/head_dim ratio is well-defined for both q and k.
583-618: Kernel-side L2 normalization flag is toggled consistently with conv fusionSetting
use_qk_l2norm_in_kernel = not self.fuse_conv_l2in bothchunk_gated_delta_ruleandfused_recurrent_gated_delta_ruleensures Q/K are L2-normalized either in the fused conv or in the kernel, but not both, preserving a single normalization step.fla/layers/gated_deltanet.py (3)
88-118: Constructor fuse_conv_l2 integration for GatedDeltaNet is consistentAdding
fuse_conv_l2and storing it asself.fuse_conv_l2 = fuse_conv_l2 and self.use_short_convcleanly ties fusion to the presence of short convolutions without changing existing defaults or validation logic.
172-195: q/k ShortConvolution now correctly opt into fused L2Conditionally enabling
norm='l2'withnorm_epsonq_conv1dandk_conv1d(while keepingv_conv1dunchanged) matches the intended design of normalizing only Q/K and matches the constructor’sfuse_conv_l2flag.
239-263: Forward conv and kernel flags preserve single-site L2 normalizationPassing
head_dim=self.head_k_dim if self.fuse_conv_l2 else Noneinto q/k ShortConvolution and settinguse_qk_l2norm_in_kernel = not self.fuse_conv_l2for both chunk and fused-recurrent kernels ensures:
- fused mode: conv handles Q/K L2, kernels skip it;
- unfused mode: conv is plain, kernels perform L2 as before.
This keeps behavior consistent while enabling the fused path.Also applies to: 281-305
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
37-67: Config-level fuse_conv_l2 flag is wired consistently
GatedDeltaNetConfignow exposesfuse_conv_l2with a sensible default and stores it asself.fuse_conv_l2, matching the pattern used for other fusion-related flags and enabling clean propagation into the layer implementation.fla/layers/gated_deltaproduct.py (2)
30-66: GatedDeltaProduct constructor integrates fuse_conv_l2 cleanlyIntroducing
fuse_conv_l2and binding it toself.fuse_conv_l2 = fuse_conv_l2 and self.use_short_convmirrors the pattern used in other layers and keeps fusion logically tied to the short-conv pathway.
120-144: ShortConvolution setup for Q/K correctly supports fused L2 with householder expansionUsing
norm='l2' if self.fuse_conv_l2 else Noneandnorm_epsonq_conv1dandk_conv1d—withhidden_sizeset tokey_dimandkey_dim * num_householderrespectively—keeps the per-head L2 normalization well-defined even when keys/values are expanded over multiple householder transforms.