178 changes: 178 additions & 0 deletions benchmarks/modules/benchmark_fused_conv_l2.py
@@ -0,0 +1,178 @@
import torch
from einops import rearrange

from fla.modules.convolution import ShortConvolution
from fla.modules.l2norm import l2norm
from fla.utils import device

def separate_conv_l2(x, conv, head_dim):
    """Separate Conv + L2 Norm"""
    y, _ = conv(x)
    y = rearrange(y, 'b t (h d) -> b t h d', d=head_dim)
    y = l2norm(y, eps=1e-5)
    y = rearrange(y, 'b t h d -> b t (h d)')
    return y

def fused_conv_l2(x, conv_fused, head_dim):
    """Fused Conv + L2 Norm"""
    y, _ = conv_fused(x, head_dim=head_dim)
    return y

if __name__ == "__main__":
    import torch.utils.benchmark as benchmark

    # Test configurations
    B, T, D, W = 4, 2048, 2048, 4
    H = 16
    head_dim = D // H

    print("="*80)
    print(f"Benchmarking Conv + L2 Norm: B={B}, T={T}, D={D}, W={W}, H={H}, head_dim={head_dim}")
    print("="*80)

    dtype = torch.bfloat16

    # Create input
    x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True)

    # Separate Conv (no norm)
    conv_separate = ShortConvolution(
        hidden_size=D,
        kernel_size=W,
        bias=False,
        activation='silu',
        norm=None,
        device=device,
        dtype=dtype,
    )

    # Fused Conv + L2 Norm
    conv_fused = ShortConvolution(
        hidden_size=D,
        kernel_size=W,
        bias=False,
        activation='silu',
        norm='l2',
        norm_eps=1e-5,
        device=device,
        dtype=dtype,
    )

    # Copy weights
    conv_fused.weight.data.copy_(conv_separate.weight.data)

    # Benchmark Forward
    print("\n" + "="*80)
    print("Forward Pass")
    print("="*80)

    t_sep_fwd = benchmark.Timer(
        stmt="separate_conv_l2(x, conv, head_dim)",
        globals={"separate_conv_l2": separate_conv_l2, "x": x, "conv": conv_separate, "head_dim": head_dim},
    )
    m_sep_fwd = t_sep_fwd.timeit(100)
    print(f"Separate: {m_sep_fwd}")

    t_fused_fwd = benchmark.Timer(
        stmt="fused_conv_l2(x, conv, head_dim)",
        globals={"fused_conv_l2": fused_conv_l2, "x": x, "conv": conv_fused, "head_dim": head_dim},
    )
    m_fused_fwd = t_fused_fwd.timeit(100)
    print(f"Fused: {m_fused_fwd}")

    # Benchmark Backward
    print("\n" + "="*80)
    print("Backward Pass")
    print("="*80)

    # Pre-compute forward for backward benchmark
    y_sep = separate_conv_l2(x, conv_separate, head_dim)
    grad_sep = torch.randn_like(y_sep)

    def backward_sep():
        for xi in [x]:
            if isinstance(xi, torch.Tensor):
                xi.grad = None
        y_sep.backward(grad_sep, retain_graph=True)

    t_sep_bwd = benchmark.Timer(
        stmt="backward_sep()",
        globals={"backward_sep": backward_sep},
    )
    m_sep_bwd = t_sep_bwd.timeit(100)
    print(f"Separate: {m_sep_bwd}")

    y_fused = fused_conv_l2(x, conv_fused, head_dim)
    grad_fused = torch.randn_like(y_fused)

    def backward_fused():
        for xi in [x]:
            if isinstance(xi, torch.Tensor):
                xi.grad = None
        y_fused.backward(grad_fused, retain_graph=True)

    t_fused_bwd = benchmark.Timer(
        stmt="backward_fused()",
        globals={"backward_fused": backward_fused},
    )
    m_fused_bwd = t_fused_bwd.timeit(100)
    print(f"Fused: {m_fused_bwd}")

    # Benchmark Combined
    print("\n" + "="*80)
    print("Forward + Backward Pass")
    print("="*80)

    def combined_sep():
        for xi in [x]:
            if isinstance(xi, torch.Tensor):
                xi.grad = None
        y = separate_conv_l2(x, conv_separate, head_dim)
        y.backward(grad_sep, retain_graph=True)

    t_sep_combined = benchmark.Timer(
        stmt="combined_sep()",
        globals={"combined_sep": combined_sep},
    )
    m_sep_combined = t_sep_combined.timeit(100)
    print(f"Separate: {m_sep_combined}")

    def combined_fused():
        for xi in [x]:
            if isinstance(xi, torch.Tensor):
                xi.grad = None
        y = fused_conv_l2(x, conv_fused, head_dim)
        y.backward(grad_fused, retain_graph=True)

Comment on lines +121 to +146
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Avoid retain_graph=True in the combined forward+backward benchmarks

In combined_sep / combined_fused you rebuild the graph on every call:

def combined_sep():
    ...
    y = separate_conv_l2(x, conv_separate, head_dim)
    y.backward(grad_sep, retain_graph=True)

def combined_fused():
    ...
    y = fused_conv_l2(x, conv_fused, head_dim)
    y.backward(grad_fused, retain_graph=True)

Because each benchmark iteration creates a new graph, retain_graph=True keeps all of them alive, causing linear memory growth and potentially OOM for these tensor sizes, as well as skewing timings.

You only need retain_graph=True in the backward‑only benchmarks where the same forward graph is reused; here you can safely drop it:

-        y = separate_conv_l2(x, conv_separate, head_dim)
-        y.backward(grad_sep, retain_graph=True)
+        y = separate_conv_l2(x, conv_separate, head_dim)
+        y.backward(grad_sep)

...
-        y = fused_conv_l2(x, conv_fused, head_dim)
-        y.backward(grad_fused, retain_graph=True)
+        y = fused_conv_l2(x, conv_fused, head_dim)
+        y.backward(grad_fused)

I'll verify the review comment by examining the actual code structure and backward/combined benchmark implementations.



🏁 Script executed:

#!/bin/bash
# Examine the benchmark file structure around the claimed lines
cat benchmarks/modules/benchmark_fused_conv_l2.py | head -150 | tail -60

Length of output: 1801


🏁 Script executed:

#!/bin/bash
# Check the backward-only benchmarks mentioned in the scratchpad (lines ~92-97, 108-113)
sed -n '80,120p' benchmarks/modules/benchmark_fused_conv_l2.py

Length of output: 1209


🏁 Script executed:

#!/bin/bash
# Get a full view of the entire benchmark file to understand structure
wc -l benchmarks/modules/benchmark_fused_conv_l2.py

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Search for all retain_graph occurrences in the file
rg -n "retain_graph" benchmarks/modules/benchmark_fused_conv_l2.py -C 3

Length of output: 1071


🏁 Script executed:

#!/bin/bash
# Look for the forward-only benchmark section to confirm structure
sed -n '1,80p' benchmarks/modules/benchmark_fused_conv_l2.py | tail -40

Length of output: 1088


Remove retain_graph=True from the combined forward+backward benchmarks

The code inspection confirms the review's analysis. The backward-only benchmarks (lines 89–112) correctly use retain_graph=True because they precompute the forward pass once and then reuse the same graph for 100 backward iterations. However, the combined benchmarks (lines 127–131 and 141–145) rebuild the forward graph on each timer iteration, making retain_graph=True unnecessary and harmful—all 100 graphs remain in memory simultaneously, causing linear memory growth and timing distortion.

Remove retain_graph=True from combined_sep() and combined_fused() at lines 131 and 145:

def combined_sep():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = separate_conv_l2(x, conv_separate, head_dim)
-   y.backward(grad_sep, retain_graph=True)
+   y.backward(grad_sep)

def combined_fused():
    for xi in [x]:
        if isinstance(xi, torch.Tensor):
            xi.grad = None
    y = fused_conv_l2(x, conv_fused, head_dim)
-   y.backward(grad_fused, retain_graph=True)
+   y.backward(grad_fused)
🤖 Prompt for AI Agents
In benchmarks/modules/benchmark_fused_conv_l2.py around lines 121 to 146, the
combined forward+backward benchmark functions combined_sep() and
combined_fused() incorrectly pass retain_graph=True to backward, which causes
every iteration to keep its autograd graph in memory; remove the
retain_graph=True argument from both y.backward(...) calls (lines ~131 and ~145)
so backward() uses its default behavior and the graph is freed each iteration;
keep the existing grad reset logic intact and do not add retain_graph elsewhere.

    t_fused_combined = benchmark.Timer(
        stmt="combined_fused()",
        globals={"combined_fused": combined_fused},
    )
    m_fused_combined = t_fused_combined.timeit(100)
    print(f"Fused: {m_fused_combined}")

    # Summary
    time_sep_fwd = m_sep_fwd.median * 1000
    time_sep_bwd = m_sep_bwd.median * 1000
    time_sep_combined = m_sep_combined.median * 1000

    time_fused_fwd = m_fused_fwd.median * 1000
    time_fused_bwd = m_fused_bwd.median * 1000
    time_fused_combined = m_fused_combined.median * 1000

    print(f"\n{'='*80}")
    print(f"{'Method':<35} {'Forward':<12} {'Backward':<12} {'Combined':<12} {'Speedup':<10}")
    print("-"*80)
    print(f"{'Separate (FLA)':<35} {time_sep_fwd:>10.3f}ms {time_sep_bwd:>10.3f}ms {time_sep_combined:>10.3f}ms {'1.00x':<10}")
    print(f"{'Fused (Recompute)':<35} {time_fused_fwd:>10.3f}ms {time_fused_bwd:>10.3f}ms {time_fused_combined:>10.3f}ms {time_sep_combined/time_fused_combined:<10.2f}x")

    speedup_fwd = (time_sep_fwd / time_fused_fwd - 1) * 100
    speedup_bwd = (time_sep_bwd / time_fused_bwd - 1) * 100
    speedup_combined = (time_sep_combined / time_fused_combined - 1) * 100

    print(f"\n{'='*80}")
    print(f"Forward Speedup: {speedup_fwd:>+8.2f}%")
    print(f"Backward Speedup: {speedup_bwd:>+8.2f}%")
    print(f"Combined Speedup: {speedup_combined:>+8.2f}%")
    print(f"\nMemory Saved: {B*T*D*2/1024/1024:.2f} MB per Conv layer (Y_act not stored)")
    print(f"{'='*80}")
12 changes: 10 additions & 2 deletions fla/layers/comba.py
@@ -91,6 +91,7 @@ def __init__(
conv_bias: bool = False,
layer_idx: int = None,
norm_eps: float = 1e-5,
fuse_conv_l2: bool = True,
**kwargs,
) -> Comba:
super().__init__()
@@ -106,6 +107,7 @@ def __init__(
self.use_inner_decay = use_inner_decay
self.conv_size = conv_size
self.conv_bias = conv_bias
self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

self.head_dim = head_dim
self.num_heads = num_heads
@@ -179,12 +181,16 @@ def __init__(
kernel_size=conv_size,
bias=conv_bias,
activation='silu',
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.k_conv1d = ShortConvolution(
hidden_size=self.key_dim,
kernel_size=conv_size,
bias=conv_bias,
activation='silu',
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.v_conv1d = ShortConvolution(
hidden_size=self.value_dim,
@@ -243,12 +249,14 @@ def forward(
cache=conv_state_q,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
)
k, conv_state_k = self.k_conv1d(
x=self.k_proj(hidden_states),
cache=conv_state_k,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
)
v, conv_state_v = self.v_conv1d(
x=self.v_proj(hidden_states),
@@ -291,7 +299,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_comba(
Expand All @@ -304,7 +312,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
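The layer changes here and in the two files below follow the same pattern: when fuse_conv_l2 is enabled, the per-head L2 normalization runs inside ShortConvolution (norm='l2' at construction, head_dim at call time), and the recurrence kernel's own q/k normalization is switched off so it is not applied twice. A simplified sketch of that dispatch, with a hypothetical helper name and most arguments omitted (not the literal layer code):

def project_qk(hidden_states, q_proj, k_proj, q_conv1d, k_conv1d, head_k_dim, fuse_conv_l2):
    # With fusion, ShortConvolution normalizes each head of width head_k_dim itself.
    hd = head_k_dim if fuse_conv_l2 else None
    q, _ = q_conv1d(q_proj(hidden_states), head_dim=hd)
    k, _ = k_conv1d(k_proj(hidden_states), head_dim=hd)
    # The kernel must then skip its own L2 norm, otherwise q/k would be normalized twice.
    return q, k, dict(use_qk_l2norm_in_kernel=not fuse_conv_l2)
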
20 changes: 18 additions & 2 deletions fla/layers/delta_net.py
@@ -87,13 +87,23 @@ def __init__(
qk_activation: str = 'silu',
qk_norm: str = 'l2',
norm_eps: float = 1e-5,
fuse_conv_l2: bool = True,
fuse_norm: bool | None = None,
**kwargs,
) -> DeltaNet:
super().__init__()

self.mode = mode
self.qk_activation = qk_activation
self.qk_norm = qk_norm
if fuse_norm is not None:
warnings.warn(
"`fuse_norm` is deprecated for DeltaNet; use `fuse_conv_l2` to control the fused "
"ShortConvolution + L2 kernel.",
stacklevel=2,
)
fuse_conv_l2 = fuse_norm
self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv and (qk_norm == 'l2')

assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
assert self.qk_norm in ['l2', 'sum']
@@ -136,12 +146,16 @@ def __init__(
kernel_size=conv_size,
bias=conv_bias,
activation='silu' if qk_activation == 'silu' else None,
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.k_conv1d = ShortConvolution(
hidden_size=self.key_dim,
kernel_size=conv_size,
bias=conv_bias,
activation='silu' if qk_activation == 'silu' else None,
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.v_conv1d = ShortConvolution(
hidden_size=self.value_dim,
@@ -200,12 +214,14 @@ def forward(
cache=conv_state_q,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None
)
k, conv_state_k = self.k_conv1d(
x=self.k_proj(hidden_states),
cache=conv_state_k,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None
)
v, conv_state_v = self.v_conv1d(
x=self.v_proj(hidden_states),
@@ -252,7 +268,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=(self.qk_norm == 'l2'),
use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2),
)
elif mode == 'chunk':
o, recurrent_state = chunk_delta_rule(
Expand All @@ -263,7 +279,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=(self.qk_norm == 'l2'),
use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2),
)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
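DeltaNet additionally keeps the old fuse_norm flag as a deprecated alias. A standalone sketch of the resolution logic mirroring the hunk above (illustrative only, not the layer code itself):

import warnings

def resolve_fuse_conv_l2(fuse_conv_l2=True, fuse_norm=None, use_short_conv=True, qk_norm='l2'):
    # An explicitly passed deprecated flag overrides fuse_conv_l2 and emits a warning.
    if fuse_norm is not None:
        warnings.warn("`fuse_norm` is deprecated for DeltaNet; use `fuse_conv_l2` instead.", stacklevel=2)
        fuse_conv_l2 = fuse_norm
    # The fused path only applies with a short conv and an L2 qk norm.
    return fuse_conv_l2 and use_short_conv and qk_norm == 'l2'
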
12 changes: 10 additions & 2 deletions fla/layers/gated_deltanet.py
@@ -100,6 +100,7 @@ def __init__(
conv_bias: bool = False,
layer_idx: int = None,
norm_eps: float = 1e-5,
fuse_conv_l2: bool = True,
**kwargs,
) -> GatedDeltaNet:
super().__init__()
@@ -113,6 +114,7 @@ def __init__(
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

self.head_dim = head_dim
self.num_heads = num_heads
@@ -174,12 +176,16 @@ def __init__(
kernel_size=conv_size,
bias=conv_bias,
activation='silu',
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.k_conv1d = ShortConvolution(
hidden_size=self.key_dim,
kernel_size=conv_size,
bias=conv_bias,
activation='silu',
norm='l2' if self.fuse_conv_l2 else None,
norm_eps=norm_eps,
)
self.v_conv1d = ShortConvolution(
hidden_size=self.value_dim,
@@ -239,12 +245,14 @@ def forward(
cache=conv_state_q,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
)
k, conv_state_k = self.k_conv1d(
x=self.k_proj(hidden_states),
cache=conv_state_k,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
)
v, conv_state_v = self.v_conv1d(
x=self.v_proj(hidden_states),
@@ -280,7 +288,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_gated_delta_rule(
Expand All @@ -292,7 +300,7 @@ def forward(
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")