Conversation

@sustcsonglin (Collaborator) commented Nov 24, 2025

Summary by CodeRabbit

  • New Features

    • High-performance fused short convolution with PyTorch autograd support, optional bias/residual/initial-state, variable-length support, and activation.
    • Built-in head-wise L2 normalization with head-dimension option; ShortConvolution exposes norm/norm_eps and propagates head_dim through forward/step.
    • New fuse_conv_l2 config/constructor flag added across multiple layers/models to enable fused conv+L2 behavior; fused_short_conv exported from ops.
  • Chores

    • Added benchmark comparing fused vs. separate conv+L2 (forward/backward/combined).


@coderabbitai coderabbitai bot (Contributor) commented Nov 24, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a Triton-backed fused short convolution operator with autograd (forward/backward kernels) and integrates optional per-head L2 normalization and activation into ShortConvolution; propagates new fuse_conv_l2/norm/head_dim flags through modules, layers, and configs and adds a benchmark comparing fused vs separate Conv+L2.

Changes

  • Fused short convolution op (fla/ops/convolution/fused_short_conv.py):
    New Triton-backed fused short convolution implementation with forward/backward kernels and a PyTorch autograd Function; supports optional weight/bias, residuals, initial state, variable-length sequences, optional per-head L2 normalization (head_dim), and activation (swish/silu).
  • Ops package export (fla/ops/convolution/__init__.py):
    Re-exports fused_short_conv from the new module.
  • ShortConvolution core, norm integration (fla/modules/convolution.py):
    Added norm and norm_eps args to ShortConvolution; threads head_dim through forward/step; selects fused_short_conv when norm='l2' (requires the Triton backend) and applies L2 normalization when configured.
  • Benchmark: fused conv + L2 (benchmarks/modules/benchmark_fused_conv_l2.py):
    New benchmark comparing separate Conv+L2 versus fused Conv with L2 baked in, covering forward, backward, and combined timing and memory estimates.
  • Delta/Comba/Gated/Mesa/MoM/KDA layers (fla/layers/delta_net.py, fla/layers/comba.py, fla/layers/gated_deltanet.py, fla/layers/gated_deltaproduct.py, fla/layers/mesa_net.py, fla/layers/mom.py, fla/layers/kda.py):
    Added a fuse_conv_l2 (or fuse_norm mapping) parameter to constructors; when use_short_conv is enabled, conditionally pass norm='l2', norm_eps, and head_dim into ShortConvolution; flip the in-kernel L2-norm flag (use_qk_l2norm_in_kernel = not fuse_conv_l2) to choose fused vs in-kernel norm behavior.
  • Model configs & model blocks propagation (fla/models/*/configuration_*.py, fla/models/*/modeling_*.py for Comba, DeltaNet, GatedDeltaNet, GatedDeltaProduct, KDA, MesaNet, Mom):
    Added fuse_conv_l2: bool = True to many model configuration classes and propagated config.fuse_conv_l2 into the corresponding layer/block constructors.
  • Layer constructor signature change (fla/layers/delta_net.py, DeltaNet.__init__):
    Constructor signature updated to accept fuse_conv_l2, with compatibility handling for the older fuse_norm param.
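
Based on the walkthrough above, a hypothetical usage sketch of the new surface; the parameter names (norm, norm_eps, head_dim) come from this summary, but the exact signatures and return values in the merged code may differ:

```python
import torch
from fla.modules import ShortConvolution

B, T, H, K = 2, 1024, 4, 64                 # batch, sequence length, heads, head dim

# norm/norm_eps are the new constructor args described above (assumed defaults).
conv = ShortConvolution(
    hidden_size=H * K,
    kernel_size=4,
    activation='silu',
    norm='l2',                              # fuse per-head L2 normalization into the conv
    norm_eps=1e-6,
)

x = torch.randn(B, T, H * K)
y, _ = conv(x, head_dim=K)                  # head_dim required when norm='l2'; tuple return is assumed
```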

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Model as ShortConvolution / caller
    participant Fn as FusedShortConvFunction
    participant Fwd as fused_short_conv_fwd_kernel (Triton)
    participant Bwd as fused_short_conv_bwd_kernel (Triton)

    Model->>Fn: apply(x, weight, bias?, residual?, initial_state?, head_dim?, norm?)
    Fn->>Fwd: launch forward kernel (HAS_WEIGHT/BIAS/RESIDUAL/NORM/ACTIVATION flags)
    Fwd->>Fwd: load inputs, optional norm per-head, activation, compute y (+final state)
    Fwd-->>Fn: return y (and final_state if requested)
    Fn->>Model: return y

    Note over Fn,Bwd: autograd backward
    Fn->>Bwd: launch backward kernel (recompute activation/norm if needed)
    Bwd->>Bwd: compute dx, dweight, dbias (aggregate/atomics)
    Bwd-->>Fn: gradients
    Fn-->>Model: propagate gradients

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • Triton kernel compile-time flag combinations and correctness of branching in fused_short_conv.py.
    • Backward recomputation strategy handling of activation and per-head L2 normalization.
    • Memory layout, head_dim reshaping, and interaction with existing convolution data formats (forward and step APIs).
    • API surface changes: ShortConvolution constructor/forward/step signatures and many configs/layer constructors.
    • Benchmark validity: ensuring weight copies and measured scenarios are comparable.

Poem

🐰 I hopped into Triton, kernels snug and tight,
Fused convs and L2 so outputs shine bright.
Residuals tucked, activations sing,
Gradients rebound — carrots on spring! 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 8.89%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly and specifically summarizes the primary addition: a fused short convolution kernel implementation with integrated L2 normalization support.
✨ Finishing touches
  • 📝 Generate docstrings

🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch fuse-conv-l2


@gemini-code-assist (Contributor)

Summary of Changes

Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new, highly optimized fused short convolution operation into the codebase. The primary goal is to provide a performant and memory-efficient solution for short convolutions, particularly when L2 normalization is required. By leveraging Triton for GPU acceleration and implementing a recomputation strategy for gradient calculation, this change aims to improve the overall efficiency of models utilizing this type of convolution.

Highlights

  • New Fused Short Convolution Kernel: Introduces a new highly optimized fused short convolution kernel implemented using Triton, designed for efficient computation on GPUs.
  • L2 Normalization Support: The kernel now includes optional L2 normalization, applied over the head dimension, enhancing its utility for various neural network architectures.
  • Memory-Efficient Backward Pass: The backward pass employs a recomputation strategy, where intermediate activations are recomputed on-the-fly instead of being stored, leading to significant memory savings without sacrificing performance.
  • Comprehensive Feature Set: Supports optional weight, bias, residual connections, initial states, and variable sequence lengths. It also includes 'swish' and 'silu' activation functions.
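
To make the recomputation idea above concrete, here is a generic, minimal sketch of the recompute-in-backward pattern using a plain torch.autograd.Function; it is purely illustrative and unrelated to the PR's Triton kernels:

```python
import torch

class RecomputedSiLU(torch.autograd.Function):
    """Save only the (cheap) input; recompute the activation when gradients are needed."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)               # no activation tensor is stored
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)                   # recomputed on the fly
        return dy * s * (1 + x * (1 - s))      # silu'(x) = sigma(x) * (1 + x * (1 - sigma(x)))


x = torch.randn(8, requires_grad=True)
RecomputedSiLU.apply(x).sum().backward()       # gradients flow without a stored activation
```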

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new fused short convolution kernel with optional L2 normalization, implemented using Triton. The implementation is comprehensive, supporting variable-length sequences, residual connections, and initial states. The use of a recomputation strategy in the backward pass is a good choice for memory efficiency.

I've identified a few areas for improvement:

  • There is some code duplication in the forward kernel that could be refactored for better maintainability.
  • The block size for the feature dimension (BD) is hardcoded in some cases, which might not be optimal for all input shapes.
  • The backward kernel has a memory access pattern for weight gradients (dw) that could be optimized for better performance.
  • There's a minor redundancy in handling 'swish' and 'silu' activations that could be cleaned up.

Overall, this is a solid contribution. Addressing these points will further improve the code's performance and maintainability.

Comment on lines +64 to +80

    if not USE_INITIAL_STATE:
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            # [BT, BD]
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi
    elif i_t * BT >= W:
        # to make Triton compiler happy, we need to copy codes
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            # [BT, BD]
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi

medium

The code blocks for the if not USE_INITIAL_STATE: and elif i_t * BT >= W: conditions are identical. This duplication can be avoided by combining the conditions, which would improve maintainability. The comment "to make Triton compiler happy" is a bit vague. If this duplication is strictly necessary for performance or to avoid a compiler bug, a more detailed explanation would be helpful. Otherwise, consider refactoring.

Suggested change

-    if not USE_INITIAL_STATE:
-        for i_w in tl.static_range(-W + 1, 1):
-            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
-            # [BT, BD]
-            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
-            if HAS_WEIGHT:
-                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
-            b_y += b_yi
-    elif i_t * BT >= W:
-        # to make Triton compiler happy, we need to copy codes
-        for i_w in tl.static_range(-W + 1, 1):
-            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
-            # [BT, BD]
-            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
-            if HAS_WEIGHT:
-                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
-            b_y += b_yi
+    if not USE_INITIAL_STATE or i_t * BT >= W:
+        for i_w in tl.static_range(-W + 1, 1):
+            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
+            # [BT, BD]
+            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
+            if HAS_WEIGHT:
+                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
+            b_y += b_yi

Comment on lines +99 to +100

        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
            b_y = b_y * tl.sigmoid(b_y)

medium

The activation functions 'swish' and 'silu' are identical. The check ACTIVATION == 'swish' or ACTIVATION == 'silu' is present here and also repeated in the backward kernel (lines 212-213 and 223-225). To improve code clarity and reduce redundancy, you could normalize the activation name in the Python wrapper function (FusedShortConvFunction). For example, you could convert 'swish' to 'silu' before passing it to the Triton kernels, and then only check for ACTIVATION == 'silu' in the kernels.
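
A minimal sketch of that wrapper-side normalization (illustrative only; the helper name and call site are assumptions):

```python
def _canonicalize_activation(activation):
    # Map 'swish' to 'silu' once on the host side so the Triton kernels only
    # ever need to check ACTIVATION == 'silu'.
    if activation is None:
        return None
    activation = activation.lower()
    if activation == 'swish':
        return 'silu'
    if activation != 'silu':
        raise ValueError(f"Unsupported activation: {activation}")
    return activation
```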

        if HAS_WEIGHT:
            b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
            b_dw = tl.sum(b_dy * b_x, 0)
            tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)

medium

The store operation to dw is not coalesced. The dw tensor has shape (B*NT, D, W), so the last dimension is the most contiguous. However, the store operation tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, ...) writes to elements with a stride of W in memory for consecutive values of o_d. This can significantly impact performance. To improve memory access patterns, consider changing the layout of dw in the Python wrapper to (B*NT, W, D) and then transposing it back to the required shape after the kernel execution.
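
A small sketch of the wrapper-side layout change being suggested (shapes and names are assumptions, not the PR's code):

```python
import torch

B_NT, D, W = 8, 128, 4                     # hypothetical program count, feature dim, window

# Allocate dw with the feature dimension innermost so the kernel's store over o_d
# hits contiguous addresses (coalesced), i.e. dw_tmp[i_tg, w, d].
dw_tmp = torch.zeros(B_NT, W, D)

# After the kernel, reduce over programs and restore the (D, W) weight-gradient layout.
dw = dw_tmp.sum(0).t().contiguous()        # (W, D) -> (D, W)
print(dw.shape)                            # torch.Size([128, 4])
```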

Comment on lines +291 to +292

    else:
        BD = 32  # Default fallback or simple value since we don't autotune

medium

When use_norm is false, BD is hardcoded to 32. This may be suboptimal for performance if the dimension D is much smaller or larger than 32. Consider making BD more adaptive to D, similar to how it's handled when use_norm is true. For example, you could use BD = min(32, triton.next_power_of_2(D)) or include BD in the autotuning configuration.

Comment on lines +343 to +344

    else:
        BD = 32

medium

Similar to the forward pass, BD is hardcoded to 32 when use_norm is false. This can lead to suboptimal performance. It would be better to determine BD dynamically based on D or include it in the autotuning process.
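
A one-line helper in the spirit of both suggestions (illustrative; the cap value is an assumption):

```python
import triton

def choose_bd(D: int, cap: int = 32) -> int:
    # Adapt the feature-block size to D instead of hardcoding 32.
    return min(cap, triton.next_power_of_2(D))
```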

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
fla/ops/convolution/fused_short_conv.py (1)

118-155: Minor: unused kernel arguments and flags can be trimmed for clarity once features are settled.

Static analysis correctly notes several unused parameters:

  • B in both kernels; y, initial_state, dh0, dht, USE_INITIAL_STATE, USE_FINAL_STATE in the backward kernel.
  • output_final_state argument in the autograd forward.

These don’t affect runtime correctness but add noise and trigger ARG00x warnings. Once you decide which state features you actually want to support, consider either:

  • Removing unused parameters and heuristics from the Triton kernels and Python wrapper, or
  • Wiring them through and using the flags for conditional logic.

Until then, you might suppress specific lints locally if you want to keep the signatures stable.

Also applies to: 245-269

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d4b33a3 and 6b6926d.

📒 Files selected for processing (1)
  • fla/ops/convolution/fused_short_conv.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/convolution/fused_short_conv.py (3)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
fla/utils.py (1)
  • get_multiprocessor_count (360-368)
fla/ops/utils/matmul.py (1)
  • sigmoid (158-160)
🪛 Flake8 (7.3.0)
fla/ops/convolution/fused_short_conv.py

[error] 205-205: continuation line under-indented for visual indent

(E128)


[error] 213-213: indentation is not a multiple of 4

(E111)


[error] 213-213: over-indented

(E117)


[error] 224-224: indentation is not a multiple of 4

(E111)


[error] 224-224: over-indented

(E117)


[error] 225-225: indentation is not a multiple of 4

(E111)


[error] 242-242: indentation is not a multiple of 4

(E111)


[error] 242-242: over-indented

(E117)


[error] 274-274: local variable 'shape' is assigned to but never used

(F841)


[error] 329-329: local variable 'shape' is assigned to but never used

(F841)

🪛 GitHub Actions: lint
fla/ops/convolution/fused_short_conv.py

[error] 275-275: ruff-check: F841 Local variable 'shape' is assigned to but never used.


[error] 330-330: ruff-check: F841 Local variable 'shape' is assigned to but never used.


[error] 1-1: Trailing whitespace detected by pre-commit hook; file was modified.


[error] 1-1: pre-commit: 4 errors found (2 fixed by hooks, 2 remaining); run with --unsafe-fixes to apply more.

🪛 Ruff (0.14.5)
fla/ops/convolution/fused_short_conv.py

28-28: Unused function argument: B

(ARG001)


128-128: Unused function argument: y

(ARG001)


131-131: Unused function argument: initial_state

(ARG001)


132-132: Unused function argument: dh0

(ARG001)


133-133: Unused function argument: dht

(ARG001)


140-140: Unused function argument: B

(ARG001)


151-151: Unused function argument: USE_INITIAL_STATE

(ARG001)


152-152: Unused function argument: USE_FINAL_STATE

(ARG001)


255-255: Unused static method argument: output_final_state

(ARG004)


274-274: Local variable shape is assigned to but never used

Remove assignment to unused variable shape

(F841)


329-329: Local variable shape is assigned to but never used

Remove assignment to unused variable shape

(F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (4)
fla/ops/convolution/fused_short_conv.py (4)

96-108: L2 norm + swish/silu backward recomputation looks consistent; main math and masking are sound.

The forward L2 norm over the head dimension (Lines 102–107) and the corresponding backward logic (Lines 187–221) correctly:

  • Recompute the convolution output per block.
  • Apply bias and optional swish/silu nonlinearity with the right derivative σ(x) * (1 + x * (1 - σ(x))).
  • Perform per‑row L2 normalization and its gradient using the dot‑product trick, while respecting EPS and the m_d mask.

Given W is short and everything is kept in fp32, the recomputation strategy here looks reasonable and numerically sound.

Also applies to: 187-221
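
For reference, the gradients being recomputed here, assuming the per-row normalization $\hat{y} = a / \sqrt{\lVert a \rVert_2^2 + \epsilon}$ with $a = \mathrm{silu}(y_{\mathrm{conv}})$ and upstream gradient $g = \partial L / \partial \hat{y}$:

$$
\mathrm{silu}'(x) = \sigma(x)\bigl(1 + x(1 - \sigma(x))\bigr),
\qquad
\frac{\partial L}{\partial a}
= \frac{1}{\sqrt{\lVert a \rVert_2^2 + \epsilon}}
\Bigl(g - \hat{y}\,\langle g, \hat{y}\rangle\Bigr).
$$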


399-407: Public wrapper API and overall fused conv+norm integration look good.

The fused_short_conv wrapper cleanly forwards all options into the autograd Function, keeps the Python API compact, and documents the recomputation strategy. Aside from the noted shape and state‑handling caveats, the core fused short‑conv + optional L2 norm design is solid and in line with the rest of the ops.


270-283: Address unused shape variables and implement symmetric reshape logic for >3D inputs.

The review identifies two confirmed issues:

  1. Unused shape variables (lines 274, 329): Both forward and backward assign shape = x.shape but never use it, causing F841 lint failures.

  2. Asymmetric flattening (lines 275–276, 330–331): The conditional rearrange(x, 'b t ... -> b t (...)') flattens x to 3D, but neither method reshapes outputs back to the original shape. If the flattening branch executes with >3D inputs, autograd will encounter a shape mismatch between saved gradients and the original tensor rank.

Choose one approach:

  • Restrict to 3D inputs (simpler): Add assert x.ndim == 3 and x.shape[-1] == weight.shape[0] and remove both shape assignments.
  • Support >3D inputs (preferred if needed): Save original shape in ctx, flatten x/residual/initial_state, then reshape outputs (y in forward, dx in backward) before returning.
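
A minimal sketch of the second option (save the original shape, flatten, reshape back); the class and the kernel call are placeholders, not the PR's implementation:

```python
import torch
from einops import rearrange

class FusedShortConvSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.orig_shape = x.shape                       # remember the caller's rank
        if x.dim() > 3:
            x = rearrange(x, 'b t ... -> b t (...)')   # flatten trailing dims to 3D
        y = x * weight.sum()                           # placeholder for the real fused kernel
        ctx.save_for_backward(x, weight)
        return y.view(ctx.orig_shape)                  # hand back the original shape

    @staticmethod
    def backward(ctx, dy):
        x, weight = ctx.saved_tensors
        dy = dy.reshape(x.shape)                       # flatten dy the same way
        dx = dy * weight.sum()                         # placeholder gradient
        return dx.view(ctx.orig_shape), None


x = torch.randn(2, 16, 4, 8, requires_grad=True)       # 4D input
w = torch.randn(8)
FusedShortConvSketch.apply(x, w).sum().backward()
print(x.grad.shape)                                    # torch.Size([2, 16, 4, 8])
```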

44-53: The file content could not be inspected automatically, so this item is flagged for manual verification. The concern is nonetheless plausible given how the kernel grid is launched, and variable-length batching typically packs all sequences into a single row (B=1 with cu_seqlens) rather than using B>1; the proposed fix is concrete and logical.

Verify the varlen grid dimension fix for potential race conditions and duplicate work when B > 1 with cu_seqlens.

When cu_seqlens is not None (varlen mode), the kernel ignores i_b but the grid still launches with B instances along the z-dimension. This may cause multiple kernel instances to compute and write to identical output locations, creating data races on shared output buffers (y, dw, db) and potentially scaling gradients by a factor of B. The proposed fix—conditioning the z-dimension size on whether varlen is enabled—is sound and should prevent redundant work. Verify this pattern exists in the actual kernel code and test both regular batch and varlen modes to ensure the fix is correct.
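
A sketch of the kind of grid-construction fix being suggested (names and grid order are assumptions):

```python
import triton

def conv_launch_grid(B, NT, D, BD, cu_seqlens=None):
    # In varlen mode all sequences are already packed along the time axis, so a
    # single program along the batch (z) axis avoids duplicate writes to y/dw/db.
    NB = 1 if cu_seqlens is not None else B
    return (NT, triton.cdiv(D, BD), NB)
```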

Comment on lines +183 to +237

    for i_w in range(0, W):
        p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
        b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)

        if USE_NORM:
            # Recompute y_conv at T_global = i_t*BT + i_w + t_local
            # We need to loop over k (kernel support) to compute convolution
            b_y_conv = tl.zeros((BT, BD), dtype=tl.float32)
            t_local = tl.arange(0, BT)

            for k in range(0, W):
                w_k = tl.sum(b_w * (o_w[None, :] == k), 1)
                # Forward: y[t] = sum_{j=0}^{W-1} x[t - W + 1 + j] * w[j]
                # Here t = i_t * BT + i_w + t_local, j = k
                # So x index = t - W + 1 + k = (i_t * BT + i_w + t_local) - W + 1 + k
                x_offset = i_t * BT + i_w - W + 1 + k
                m_x_valid = (x_offset + t_local >= 0) & (x_offset + t_local < T)

                # We need to reload x from memory as it's not in registers.
                # Constructing pointers manually to allow random access in loop
                # This is efficient enough for small W.
                val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
                               mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
                b_y_conv += val_x * w_k[None, :]

            if HAS_BIAS:
                b_y_conv += tl.load(bias + o_d, mask=m_d).to(tl.float32)

            b_y_act = b_y_conv
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                 b_y_act = b_y_conv * tl.sigmoid(b_y_conv)

            b_var = tl.sum(b_y_act * b_y_act, 1)
            b_std = tl.sqrt(b_var + EPS)
            b_inv_std = 1.0 / b_std
            b_y_out = b_y_act * b_inv_std[:, None]
            b_dot = tl.sum(b_dy * b_y_out, 1)
            b_dy = (b_dy - b_y_out * b_dot[:, None]) * b_inv_std[:, None]

            # For activation backward
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                 b_sig = tl.sigmoid(b_y_conv)
                 b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))

        b_wdy = b_dy
        if HAS_WEIGHT:
            b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
            b_dw = tl.sum(b_dy * b_x, 0)
            tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)

        if HAS_BIAS and i_w == 0:
            b_db += tl.sum(b_dy, 0)

        b_dx += b_wdy

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Backward kernel currently ignores initial_state/dh0/dht; output_final_state is accepted but not implemented.

In fused_short_conv_bwd_kernel, the parameters initial_state, dh0, and dht (Lines 131–133) and the flags USE_INITIAL_STATE / USE_FINAL_STATE are unused; gradients for the initial state and final state are effectively always zero. In the Python autograd wrapper:

  • forward always returns (y, None) (Line 317), regardless of output_final_state.
  • backward accepts dht but does nothing with it and always returns dh0=None (Line 347, 389).

If initial/final state gradients or a real final state output are not supported yet, it is safer to:

  • Explicitly assert that output_final_state is False and initial_state.requires_grad is False (or raise NotImplementedError when violated).
  • Drop unused heuristics/args from the Triton kernel until they’re actually used, to avoid confusion.

Alternatively, implement the missing state logic and wire USE_INITIAL_STATE/USE_FINAL_STATE through the kernels.

Also applies to: 354-387


I'll generate shell scripts to verify the claims in the review comment about unused parameters and unimplemented features.


---

🏁 Script executed:

#!/bin/bash
# Verify backward kernel parameter declarations and usage
echo "=== Checking backward kernel signature and parameters ==="
rg -n "def fused_short_conv_bwd_kernel|initial_state|dh0|dht|USE_INITIAL_STATE|USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -A 2 -B 2

Length of output: 3729


🏁 Script executed:

#!/bin/bash
# Check forward function signature and return statements
echo "=== Checking forward function behavior ==="
rg -n "def forward|return.*output_final_state|output_final_state.*return" fla/ops/convolution/fused_short_conv.py -A 3 -B 1

Length of output: 164


🏁 Script executed:

#!/bin/bash
# Check backward function signature and return statements
echo "=== Checking backward function behavior ==="
rg -n "def backward|return dh0|return.*dh0" fla/ops/convolution/fused_short_conv.py -A 2 -B 1

Length of output: 383


🏁 Script executed:

#!/bin/bash
# Get line ranges mentioned in the review to examine full context
echo "=== Lines 131-133 (kernel parameters) ==="
sed -n '131,133p' fla/ops/convolution/fused_short_conv.py

echo -e "\n=== Lines 317 (forward return) ==="
sed -n '310,325p' fla/ops/convolution/fused_short_conv.py

echo -e "\n=== Lines 347, 389 (backward) ==="
sed -n '340,395p' fla/ops/convolution/fused_short_conv.py

echo -e "\n=== Lines 354-387 (also applies to) ==="
sed -n '354,387p' fla/ops/convolution/fused_short_conv.py

Length of output: 3410


🏁 Script executed:

#!/bin/bash
# Search for any conditional logic using these flags in the kernel
echo "=== Search for flag usage patterns ==="
rg -n "USE_INITIAL_STATE|USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -B 2 -A 2

Length of output: 1274


Let me search more thoroughly to ensure these flags aren't used anywhere in the backward kernel body:


---

🏁 Script executed:

```shell
#!/bin/bash
# Get the complete backward kernel body to inspect for flag usage
echo "=== Full backward kernel body (lines 126-237) ==="
sed -n '126,237p' fla/ops/convolution/fused_short_conv.py

Length of output: 4176


🏁 Script executed:

#!/bin/bash
# Search for any conditional branching on these flags in backward kernel
echo "=== Search for if statements using USE_INITIAL_STATE or USE_FINAL_STATE in backward context ==="
rg -n "if.*USE_INITIAL_STATE|if.*USE_FINAL_STATE" fla/ops/convolution/fused_short_conv.py -B 1 -A 2

Length of output: 343


🏁 Script executed:

#!/bin/bash
# Verify: does backward kernel ever reference dh0 or dht besides parameters?
echo "=== References to dh0 in backward context (after line 126) ==="
awk 'NR>=126 && NR<=237 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py | grep -i "dh0\|dht"

Length of output: 90


🏁 Script executed:

#!/bin/bash
# Check what parameters backward kernel accepts and how many it should return
echo "=== Backward function signature and return statement structure ==="
awk 'NR>=320 && NR<=330 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py
echo ""
echo "=== Forward function signature to understand parameter count ==="
awk 'NR>=248 && NR<=270 {print NR": "$0}' fla/ops/convolution/fused_short_conv.py

Length of output: 1465


Implement missing backward pass state gradients or add explicit NotImplementedError guards.

The verification confirms all claims: the backward kernel declares USE_INITIAL_STATE, USE_FINAL_STATE, dh0, and dht (lines 131–133, 151–152) but never uses them in the kernel body (lines 126–237). The Python wrapper accepts dht in backward() but always returns dh0=None (line 389). Additionally, output_final_state is accepted but hardcoded to return None (line 317).

Choose one approach:

  1. Implement: Wire USE_INITIAL_STATE/USE_FINAL_STATE conditionals through the backward kernel and compute/return proper gradients for initial/final states.
  2. Guard: Add assertions in the Python wrapper to reject unsupported cases:
    • assert not output_final_state, "output_final_state not implemented"
    • assert not (initial_state is not None and initial_state.requires_grad), "initial_state gradients not implemented"

Leaving unused parameters and flags creates API confusion and makes future maintenance harder.
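
A minimal sketch of the guard option (argument names are taken from the wrapper as described in this review; adjust to the actual signature):

```python
def _reject_unsupported_state_options(output_final_state, initial_state):
    # Fail loudly instead of silently returning None/zero state gradients.
    if output_final_state:
        raise NotImplementedError("output_final_state is not implemented in the fused path")
    if initial_state is not None and initial_state.requires_grad:
        raise NotImplementedError("gradients w.r.t. initial_state are not implemented")
```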

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 205-205: continuation line under-indented for visual indent

(E128)


[error] 213-213: indentation is not a multiple of 4

(E111)


[error] 213-213: over-indented

(E117)


[error] 224-224: indentation is not a multiple of 4

(E111)


[error] 224-224: over-indented

(E117)


[error] 225-225: indentation is not a multiple of 4

(E111)

🤖 Prompt for AI Agents
In fla/ops/convolution/fused_short_conv.py around lines 183-237 (flags declared
around 131-133 and 151-152; Python wrapper behavior at ~317 and ~389): the
kernel and wrapper declare/accept USE_INITIAL_STATE/USE_FINAL_STATE, dh0 and dht
and output_final_state but never compute or return those gradients; either
implement state-gradient flow or explicitly guard/disable the feature. Option A
(guard): add assertions in the Python wrapper to reject output_final_state=True
and initial_state.requires_grad (e.g., assert not output_final_state,
"output_final_state not implemented"; assert not (initial_state is not None and
initial_state.requires_grad), "initial_state gradients not implemented"), and
clear/remove the unused kernel flags/parameters to avoid dead API. Option B
(implement): wire state gradients through the kernel by accumulating gradients
for initial and final states similarly to b_dx (compute per-step contributions
to dht and dh0 based on recurrence and normalization/activation derivatives),
store them into the kernel output buffers, and update the Python wrapper to
return concrete dh0/dht tensors instead of None; ensure shapes/dtypes match and
guarded by USE_INITIAL_STATE/USE_FINAL_STATE. Choose one approach and apply the
corresponding changes consistently in both kernel and wrapper.

Comment on lines +204 to +206

                val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
                               mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
                b_y_conv += val_x * w_k[None, :]

🛠️ Refactor suggestion | 🟠 Major

Fix indentation and continuation formatting to satisfy flake8 and pre‑commit.

The lint errors correspond to:

  • The continuation line for val_x (Line 205) being under‑indented for its visual indent.
  • Over‑indented lines for b_y_act, b_sig, b_dy (Lines 213, 224, 225).
  • Over‑indented tl.store for db (Line 242).
  • Plus trailing whitespace noted by pre‑commit around the file header.

A minimal compliant patch would look like:

-                val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
-                               mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32)
+                val_x = tl.load(
+                    x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :],
+                    mask=m_x_valid[:, None] & m_d[None, :],
+                    other=0.0,
+                ).to(tl.float32)
@@
-            b_y_act = b_y_conv
-            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
-                 b_y_act = b_y_conv * tl.sigmoid(b_y_conv)
+            b_y_act = b_y_conv
+            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
+                b_y_act = b_y_conv * tl.sigmoid(b_y_conv)
@@
-            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
-                 b_sig = tl.sigmoid(b_y_conv)
-                 b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))
+            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
+                b_sig = tl.sigmoid(b_y_conv)
+                b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig))
@@
-    if HAS_BIAS:
-         tl.store(db + i_tg * D + o_d, b_db.to(db.dtype.element_ty), mask=m_d)
+    if HAS_BIAS:
+        tl.store(db + i_tg * D + o_d, b_db.to(db.dtype.element_ty), mask=m_d)

…and remove any trailing whitespace on the file header line. This should clear the flake8 and pre‑commit errors listed in the pipeline.

Also applies to: 211-226, 241-242

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 205-205: continuation line under-indented for visual indent

(E128)

🤖 Prompt for AI Agents
In fla/ops/convolution/fused_short_conv.py around lines 204-206 (and also apply
to 211-226, 241-242), fix indentation and line-continuation formatting to
satisfy flake8: align the continued tl.load call so the continuation expression
is indented to the visual indent level (e.g., align with the opening parenthesis
or add one extra indent level), adjust the indentation of subsequent assignments
b_y_act, b_sig, b_dy to match their block scope (remove the extra spaces), align
the tl.store for db with its surrounding block (remove over-indentation), and
remove any trailing whitespace on the file header line; run the linter to verify
no further continuation/indentation issues remain.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6b6926d and 562a840.

📒 Files selected for processing (3)
  • benchmarks/modules/benchmark_fused_conv_l2.py (1 hunks)
  • fla/layers/delta_net.py (5 hunks)
  • fla/modules/convolution.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
fla/layers/delta_net.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/modules/convolution.py (1)
fla/ops/convolution/fused_short_conv.py (1)
  • fused_short_conv (391-413)
benchmarks/modules/benchmark_fused_conv_l2.py (2)
fla/modules/convolution.py (2)
  • ShortConvolution (806-1071)
  • backward (675-689)
fla/modules/l2norm.py (1)
  • l2norm (263-268)
🪛 GitHub Actions: lint
benchmarks/modules/benchmark_fused_conv_l2.py

[error] 1-1: Pre-commit trailing-whitespace hook detected issues and fixed them.

🪛 Ruff (0.14.5)
fla/layers/delta_net.py

91-91: Unused method argument: kwargs

(ARG002)

fla/modules/convolution.py

982-982: Avoid specifying long messages outside the exception class

(TRY003)


1037-1037: Avoid specifying long messages outside the exception class

(TRY003)


1062-1062: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (1)
fla/modules/convolution.py (1)

834-842: Fused L2 normalization integration in ShortConvolution looks consistent

The way ShortConvolution now handles normalization and head_dim is coherent:

  • norm/norm_eps are stored and validated so that only norm='l2' with the Triton backend is allowed at construction.
  • In forward, enabling norm switches to the fused_short_conv path and enforces head_dim at call time, which matches how DeltaNet and the benchmark supply per‑head dimensions.
  • In step, both Triton and CUDA update paths apply the same post‑conv L2 normalization when norm is set, again guarded by head_dim, so streaming and full‑sequence behavior stay aligned.

Given the current usages (only Q/K convs set norm and they all pass head_dim), this design is sound and doesn’t disturb existing norm=None call sites.

Also applies to: 860-870, 909-997, 1012-1067

Comment on lines +121 to +146

# Benchmark Combined
print("\n" + "="*80)
print("Forward + Backward Pass")
print("="*80)

def combined_sep():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = separate_conv_l2(x, conv_separate, head_dim)
    y.backward(grad_sep, retain_graph=True)

t_sep_combined = benchmark.Timer(
    stmt="combined_sep()",
    globals={"combined_sep": combined_sep},
)
m_sep_combined = t_sep_combined.timeit(100)
print(f"Separate: {m_sep_combined}")

def combined_fused():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = fused_conv_l2(x, conv_fused, head_dim)
    y.backward(grad_fused, retain_graph=True)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Avoid retain_graph=True in the combined forward+backward benchmarks

In combined_sep / combined_fused you rebuild the graph on every call:

def combined_sep():
    ...
    y = separate_conv_l2(x, conv_separate, head_dim)
    y.backward(grad_sep, retain_graph=True)

def combined_fused():
    ...
    y = fused_conv_l2(x, conv_fused, head_dim)
    y.backward(grad_fused, retain_graph=True)

Because each benchmark iteration creates a new graph, retain_graph=True keeps all of them alive, causing linear memory growth and potentially OOM for these tensor sizes, as well as skewing timings.

You only need retain_graph=True in the backward‑only benchmarks where the same forward graph is reused; here you can safely drop it:

-        y = separate_conv_l2(x, conv_separate, head_dim)
-        y.backward(grad_sep, retain_graph=True)
+        y = separate_conv_l2(x, conv_separate, head_dim)
+        y.backward(grad_sep)

...
-        y = fused_conv_l2(x, conv_fused, head_dim)
-        y.backward(grad_fused, retain_graph=True)
+        y = fused_conv_l2(x, conv_fused, head_dim)
+        y.backward(grad_fused)

I'll verify the review comment by examining the actual code structure and backward/combined benchmark implementations.



🏁 Script executed:

#!/bin/bash
# Examine the benchmark file structure around the claimed lines
cat benchmarks/modules/benchmark_fused_conv_l2.py | head -150 | tail -60

Length of output: 1801


🏁 Script executed:

#!/bin/bash
# Check the backward-only benchmarks mentioned in the scratchpad (lines ~92-97, 108-113)
sed -n '80,120p' benchmarks/modules/benchmark_fused_conv_l2.py

Length of output: 1209


🏁 Script executed:

#!/bin/bash
# Get a full view of the entire benchmark file to understand structure
wc -l benchmarks/modules/benchmark_fused_conv_l2.py

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Search for all retain_graph occurrences in the file
rg -n "retain_graph" benchmarks/modules/benchmark_fused_conv_l2.py -C 3

Length of output: 1071


🏁 Script executed:

#!/bin/bash
# Look for the forward-only benchmark section to confirm structure
sed -n '1,80p' benchmarks/modules/benchmark_fused_conv_l2.py | tail -40

Length of output: 1088


Remove retain_graph=True from the combined forward+backward benchmarks

The code inspection confirms the review's analysis. The backward-only benchmarks (lines 89–112) correctly use retain_graph=True because they precompute the forward pass once and then reuse the same graph for 100 backward iterations. However, the combined benchmarks (lines 127–131 and 141–145) rebuild the forward graph on each timer iteration, making retain_graph=True unnecessary and harmful—all 100 graphs remain in memory simultaneously, causing linear memory growth and timing distortion.

Remove retain_graph=True from combined_sep() and combined_fused() at lines 131 and 145:

def combined_sep():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = separate_conv_l2(x, conv_separate, head_dim)
-   y.backward(grad_sep, retain_graph=True)
+   y.backward(grad_sep)

def combined_fused():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = fused_conv_l2(x, conv_fused, head_dim)
-   y.backward(grad_fused, retain_graph=True)
+   y.backward(grad_fused)
🤖 Prompt for AI Agents
In benchmarks/modules/benchmark_fused_conv_l2.py around lines 121 to 146, the
combined forward+backward benchmark functions combined_sep() and
combined_fused() incorrectly pass retain_graph=True to backward, which causes
every iteration to keep its autograd graph in memory; remove the
retain_graph=True argument from both y.backward(...) calls (lines ~131 and ~145)
so backward() uses its default behavior and the graph is freed each iteration;
keep the existing grad reset logic intact and do not add retain_graph elsewhere.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
fla/layers/kda.py (1)

74-75: fuse_conv_l2 gating and L2 placement look logically sound

The interplay between self.fuse_conv_l2, ShortConvolution(norm='l2', head_dim=...), and use_qk_l2norm_in_kernel=not self.fuse_conv_l2 is consistent:

  • use_short_conv=False ⇒ self.fuse_conv_l2=False, KDA kernel keeps handling q/k L2 as before.
  • use_short_conv=True, fuse_conv_l2=True ⇒ L2 done inside ShortConvolution (fused conv+L2), and KDA kernel L2 is disabled, avoiding double normalization.
  • use_short_conv=True, fuse_conv_l2=False ⇒ short conv without norm, KDA kernel L2 stays enabled (restoring prior behavior).

This cleanly ensures exactly one L2 path is active at a time. You might optionally add a short note in the class docstring explaining that fuse_conv_l2 moves q/k L2 from the KDA kernel into the short-conv path when enabled.

Also applies to: 87-88, 121-137, 199-211, 240-262
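
A tiny, purely illustrative model of the "exactly one L2 path" invariant described above (the flag names mirror the PR; the function itself is not part of the code):

```python
def l2_sites(use_short_conv: bool, fuse_conv_l2: bool) -> dict:
    # Mirrors `self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv` and
    # `use_qk_l2norm_in_kernel = not self.fuse_conv_l2`.
    fused = fuse_conv_l2 and use_short_conv
    return {"in_conv": fused, "in_kernel": not fused}

assert l2_sites(True, True) == {"in_conv": True, "in_kernel": False}
assert l2_sites(True, False) == {"in_conv": False, "in_kernel": True}
assert l2_sites(False, True) == {"in_conv": False, "in_kernel": True}
```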

fla/layers/delta_net.py (1)

72-107: fuse_conv_l2 / kernel L2 interplay looks correct; consider minor cleanups

The new wiring in DeltaNet keeps L2 normalization semantics consistent while enabling the fused conv path:

  • self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv and (qk_norm == 'l2') ensures the fused conv+L2 path only activates when both short conv and L2 q/k norm are actually in use, and it naturally falls back when use_short_conv=False or qk_norm!='l2'.
  • The deprecated fuse_norm alias mapping into fuse_conv_l2 with a warning gives a smooth migration path for older call sites.
  • Passing norm='l2' if self.fuse_conv_l2 else None and norm_eps=norm_eps into the q/k ShortConvolution modules, and then head_dim=self.head_k_dim if self.fuse_conv_l2 else None in forward, lines up with ShortConvolution’s fused‑norm contract (head_dim required only when norm is active).
  • Using use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2) in both fused_recurrent_delta_rule and chunk_delta_rule ensures q/k are normalized exactly once: in the fused conv when fusion is on, or in the kernel when it’s off—including the use_short_conv=False case, which preserves the original behavior.

Minor follow‑ups you might consider:

  • __init__(..., **kwargs) currently doesn’t use kwargs and triggers Ruff’s ARG002; if you want to keep the flexible signature, adding a small del kwargs # noqa: ARG002 or similar would silence the warning without changing behavior.
  • The class docstring doesn’t yet mention fuse_conv_l2/fuse_norm; adding a brief arg description would help users understand how to toggle the fused conv+L2 path.

Also applies to: 139-159, 208-225, 261-283

fla/layers/mom.py (1)

689-705: shared_o path mirrors main path for fused conv L2 behavior

The shared output branch correctly reuses q/k ShortConvolution with head_dim gated by self.fuse_conv_l2 and passes use_qk_l2norm_in_kernel = not self.fuse_conv_l2 to both chunk and fused-recurrent kernels, keeping fused vs. unfused semantics aligned between main and shared branches.

You may later want to factor the repeated q/k-conv + kernel-call wiring into a small helper to reduce duplication across the main and shared_o paths.

Also applies to: 719-742

fla/layers/gated_deltaproduct.py (1)

196-220: Forward conv path and kernel flags maintain correct L2-placement with householder structure

Passing head_dim=self.head_k_dim into q/k ShortConvolution only when self.fuse_conv_l2 is enabled, and setting use_qk_l2norm_in_kernel = not self.fuse_conv_l2 for both chunk_gated_delta_product and fused_recurrent_gated_delta_rule, ensures:

  • Q/K are normalized once (in conv or kernel, not both),
  • the additional num_householder factor is handled via reshapes without breaking per-head normalization semantics.

Consider a small shared helper for the repeated “q/k conv + head_dim + use_qk_l2norm_in_kernel” wiring across GatedDeltaNet/GatedDeltaProduct to reduce copy-paste when adding future fusion options.

Also applies to: 242-276

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 562a840 and f1d9db7.

📒 Files selected for processing (22)
  • fla/layers/comba.py (6 hunks)
  • fla/layers/delta_net.py (5 hunks)
  • fla/layers/gated_deltanet.py (6 hunks)
  • fla/layers/gated_deltaproduct.py (6 hunks)
  • fla/layers/kda.py (6 hunks)
  • fla/layers/mesa_net.py (5 hunks)
  • fla/layers/mom.py (10 hunks)
  • fla/models/comba/configuration_comba.py (2 hunks)
  • fla/models/comba/modeling_comba.py (1 hunks)
  • fla/models/delta_net/configuration_delta_net.py (2 hunks)
  • fla/models/delta_net/modeling_delta_net.py (1 hunks)
  • fla/models/gated_deltanet/configuration_gated_deltanet.py (2 hunks)
  • fla/models/gated_deltanet/modeling_gated_deltanet.py (1 hunks)
  • fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py (2 hunks)
  • fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (1 hunks)
  • fla/models/kda/configuration_kda.py (2 hunks)
  • fla/models/kda/modeling_kda.py (1 hunks)
  • fla/models/mesa_net/configuration_mesa_net.py (2 hunks)
  • fla/models/mesa_net/modeling_mesa_net.py (1 hunks)
  • fla/models/mom/configuration_mom.py (2 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
  • fla/ops/convolution/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
fla/layers/comba.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/layers/mom.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/ops/convolution/__init__.py (1)
fla/ops/convolution/fused_short_conv.py (1)
  • fused_short_conv (391-413)
fla/layers/kda.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/layers/mesa_net.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/layers/gated_deltaproduct.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/layers/gated_deltanet.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
fla/layers/delta_net.py (1)
fla/modules/convolution.py (1)
  • ShortConvolution (806-1071)
🪛 GitHub Actions: lint
fla/ops/convolution/__init__.py

[error] 1-1: End-of-file fixer modified the file to ensure proper EOF newline.

🪛 Ruff (0.14.5)
fla/layers/delta_net.py

92-92: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (30)
fla/ops/convolution/__init__.py (1)

1-3: Re-export looks good

Simple, correct exposure of fused_short_conv via the package namespace; nothing else to change here.

fla/models/kda/configuration_kda.py (1)

35-63: Config wiring for fuse_conv_l2 is consistent

Adding fuse_conv_l2 at the end of the signature and storing it alongside other fuse_* flags is backward compatible for positional callers and matches how it’s used downstream.

fla/models/mom/modeling_mom.py (1)

134-151: MoM attention now correctly receives fuse_conv_l2

Forwarding config.fuse_conv_l2 into MomAttention completes the config-driven wiring for this backend without altering existing control flow.

fla/models/comba/modeling_comba.py (1)

56-69: Comba block fuse_conv_l2 propagation is consistent

Passing config.fuse_conv_l2 into Comba aligns this model with the rest of the stack and doesn’t disturb existing behavior when the flag keeps its default.

fla/models/mom/configuration_mom.py (1)

44-44: LGTM!

The addition of the fuse_conv_l2 configuration parameter follows the established pattern for other fusion flags in the codebase and is properly stored as an instance attribute.

Also applies to: 79-79

fla/models/delta_net/modeling_delta_net.py (1)

69-69: LGTM!

The parameter is correctly forwarded from the configuration to the DeltaNet layer constructor, maintaining consistency with other configuration options.

fla/models/kda/modeling_kda.py (1)

68-68: LGTM!

The parameter is correctly forwarded from the configuration to the KimiDeltaAttention layer constructor.

fla/models/comba/configuration_comba.py (1)

39-39: LGTM!

The addition of the fuse_conv_l2 configuration parameter is consistent with the pattern established in other configuration classes.

Also applies to: 71-71

fla/models/delta_net/configuration_delta_net.py (1)

40-40: LGTM!

The addition of the fuse_conv_l2 configuration parameter follows the established pattern and is properly initialized.

Also applies to: 71-71

fla/models/gated_deltanet/modeling_gated_deltanet.py (1)

69-69: LGTM!

The parameter is correctly forwarded from the configuration to the GatedDeltaNet layer constructor.

fla/models/mesa_net/configuration_mesa_net.py (1)

42-42: LGTM!

The addition of the fuse_conv_l2 configuration parameter is consistent with the pattern used in other configuration classes.

Also applies to: 72-72

fla/layers/mesa_net.py (5)

69-69: LGTM!

The fuse_conv_l2 parameter is properly added and stored as an instance attribute, following the established pattern in the codebase.

Also applies to: 90-90


106-121: LGTM!

The ShortConvolution layers are correctly configured with L2 normalization when fuse_conv_l2 is enabled. The conditional norm='l2' if self.fuse_conv_l2 else None pattern properly enables/disables the fused normalization path, and norm_eps is consistently provided for both q and k convolutions.


158-171: LGTM!

The head_dim parameter is correctly passed conditionally based on fuse_conv_l2. When fusion is enabled, head_dim is provided to enable per-head normalization within the convolution operation. The pattern is consistent for both q and k projections.


185-197: LGTM!

The use_qk_l2norm_in_kernel flag is correctly set to not self.fuse_conv_l2, ensuring that L2 normalization is only applied in the kernel when it's not already fused into the convolution operation. This prevents double normalization.


200-202: LGTM!

The decoding path correctly applies L2 normalization to q and k only when fuse_conv_l2 is False, maintaining consistency with the training path and preventing redundant normalization when fusion is enabled.

fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py (1)

56-71: fuse_conv_l2 correctly threaded into GatedDeltaProduct

Passing fuse_conv_l2=config.fuse_conv_l2 into GatedDeltaProduct keeps the new fused short‑conv+L2 behavior configurable at the model level and consistent with the updated config; the change is localized and backward compatible with the default value on the config side.

fla/models/mesa_net/modeling_mesa_net.py (1)

56-70: Consistent propagation of fuse_conv_l2 into MesaNet

Forwarding fuse_conv_l2=config.fuse_conv_l2 into MesaNet matches the new layer API and keeps the fusion toggle controllable from the config without altering existing call sites.

fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py (1)

11-45: New fuse_conv_l2 config flag is cleanly integrated

Adding fuse_conv_l2: bool = True and storing it as self.fuse_conv_l2 gives a clear config‑level switch for the fused short‑conv+L2 path while preserving backward compatibility via the defaulted argument. If you keep separate config docs/JSON schema, consider listing this flag there as well for discoverability.

Also applies to: 65-70

fla/layers/comba.py (1)

77-96: Comba’s fuse_conv_l2 gating is consistent and preserves prior behavior

Comba’s new fuse_conv_l2 path looks well‑designed:

  • self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv cleanly disables fusion when short conv is off.
  • q/k ShortConvolution layers use norm='l2' and norm_eps=norm_eps only when fusion is enabled, and forward passes head_dim=self.head_k_dim in that case, aligning with ShortConvolution’s fused‑norm requirements.
  • use_qk_l2norm_in_kernel=not self.fuse_conv_l2 in both chunk_comba and fused_recurrent_comba ensures q/k L2 normalization happens either in the conv or in the kernel, but not both, and that turning off fuse_conv_l2 (or use_short_conv) reverts to the original in‑kernel‑only scheme.

Also applies to: 104-120, 177-200, 248-260, 291-316

fla/layers/mom.py (4)

279-317: fuse_conv_l2 flag wiring in MomAttention constructor is coherent

Conditioning self.fuse_conv_l2 on self.use_short_conv keeps fusion disabled when the conv path is off, avoiding inconsistent states; ctor signature and attribute setup look consistent with the rest of the class.


379-396: ShortConvolution q/k initialization correctly gates fused L2 normalization

Passing norm='l2' and norm_eps only when self.fuse_conv_l2 is enabled cleanly switches between fused and unfused conv paths while leaving v_conv1d untouched, which matches the intended “Q/K-only” normalization.


517-545: head_dim propagation into q/k ShortConvolution calls matches fused L2 expectations

Using head_dim=self.head_qk_dim if self.fuse_conv_l2 else None ensures the fused path receives the per-head size it needs, while the unfused path behaves as before. Given self.key_dim = num_heads * head_dim, the hidden size/head_dim ratio is well-defined for both q and k.


583-618: Kernel-side L2 normalization flag is toggled consistently with conv fusion

Setting use_qk_l2norm_in_kernel = not self.fuse_conv_l2 in both chunk_gated_delta_rule and fused_recurrent_gated_delta_rule ensures Q/K are L2-normalized either in the fused conv or in the kernel, but not both, preserving a single normalization step.

fla/layers/gated_deltanet.py (3)

88-118: Constructor fuse_conv_l2 integration for GatedDeltaNet is consistent

Adding fuse_conv_l2 and storing it as self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv cleanly ties fusion to the presence of short convolutions without changing existing defaults or validation logic.


172-195: q/k ShortConvolution now correctly opt into fused L2

Conditionally enabling norm='l2' with norm_eps on q_conv1d and k_conv1d (while keeping v_conv1d unchanged) matches the intended design of normalizing only Q/K and matches the constructor’s fuse_conv_l2 flag.


239-263: Forward conv and kernel flags preserve single-site L2 normalization

Passing head_dim=self.head_k_dim if self.fuse_conv_l2 else None into q/k ShortConvolution and setting use_qk_l2norm_in_kernel = not self.fuse_conv_l2 for both chunk and fused-recurrent kernels ensures:

  • fused mode: conv handles Q/K L2, kernels skip it;
  • unfused mode: conv is plain, kernels perform L2 as before.
    This keeps behavior consistent while enabling the fused path.

Also applies to: 281-305

fla/models/gated_deltanet/configuration_gated_deltanet.py (1)

37-67: Config-level fuse_conv_l2 flag is wired consistently

GatedDeltaNetConfig now exposes fuse_conv_l2 with a sensible default and stores it as self.fuse_conv_l2, matching the pattern used for other fusion-related flags and enabling clean propagation into the layer implementation.

fla/layers/gated_deltaproduct.py (2)

30-66: GatedDeltaProduct constructor integrates fuse_conv_l2 cleanly

Introducing fuse_conv_l2 and binding it to self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv mirrors the pattern used in other layers and keeps fusion logically tied to the short-conv pathway.


120-144: ShortConvolution setup for Q/K correctly supports fused L2 with householder expansion

Using norm='l2' if self.fuse_conv_l2 else None and norm_eps on q_conv1d and k_conv1d—with hidden_size set to key_dim and key_dim * num_householder respectively—keeps the per-head L2 normalization well-defined even when keys/values are expanded over multiple householder transforms.
