-
Notifications
You must be signed in to change notification settings - Fork 322
Add fused short convolution kernel with L2 norm #661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sustcsonglin
wants to merge
3
commits into
main
Choose a base branch
from
fuse-conv-l2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| import torch | ||
| from einops import rearrange | ||
|
|
||
| from fla.modules.convolution import ShortConvolution | ||
| from fla.modules.l2norm import l2norm | ||
| from fla.utils import device | ||
|
|
||
| def separate_conv_l2(x, conv, head_dim): | ||
| """Separate Conv + L2 Norm""" | ||
| y, _ = conv(x) | ||
| y = rearrange(y, 'b t (h d) -> b t h d', d=head_dim) | ||
| y = l2norm(y, eps=1e-5) | ||
| y = rearrange(y, 'b t h d -> b t (h d)') | ||
| return y | ||
|
|
||
| def fused_conv_l2(x, conv_fused, head_dim): | ||
| """Fused Conv + L2 Norm""" | ||
| y, _ = conv_fused(x, head_dim=head_dim) | ||
| return y | ||
|
|
||
| if __name__ == "__main__": | ||
| import torch.utils.benchmark as benchmark | ||
|
|
||
| # Test configurations | ||
| B, T, D, W = 4, 2048, 2048, 4 | ||
| H = 16 | ||
| head_dim = D // H | ||
|
|
||
| print("="*80) | ||
| print(f"Benchmarking Conv + L2 Norm: B={B}, T={T}, D={D}, W={W}, H={H}, head_dim={head_dim}") | ||
| print("="*80) | ||
|
|
||
| dtype = torch.bfloat16 | ||
|
|
||
| # Create input | ||
| x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) | ||
|
|
||
| # Separate Conv (no norm) | ||
| conv_separate = ShortConvolution( | ||
| hidden_size=D, | ||
| kernel_size=W, | ||
| bias=False, | ||
| activation='silu', | ||
| norm=None, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| # Fused Conv + L2 Norm | ||
| conv_fused = ShortConvolution( | ||
| hidden_size=D, | ||
| kernel_size=W, | ||
| bias=False, | ||
| activation='silu', | ||
| norm='l2', | ||
| norm_eps=1e-5, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| # Copy weights | ||
| conv_fused.weight.data.copy_(conv_separate.weight.data) | ||
|
|
||
| # Benchmark Forward | ||
| print("\n" + "="*80) | ||
| print("Forward Pass") | ||
| print("="*80) | ||
|
|
||
| t_sep_fwd = benchmark.Timer( | ||
| stmt="separate_conv_l2(x, conv, head_dim)", | ||
| globals={"separate_conv_l2": separate_conv_l2, "x": x, "conv": conv_separate, "head_dim": head_dim}, | ||
| ) | ||
| m_sep_fwd = t_sep_fwd.timeit(100) | ||
| print(f"Separate: {m_sep_fwd}") | ||
|
|
||
| t_fused_fwd = benchmark.Timer( | ||
| stmt="fused_conv_l2(x, conv, head_dim)", | ||
| globals={"fused_conv_l2": fused_conv_l2, "x": x, "conv": conv_fused, "head_dim": head_dim}, | ||
| ) | ||
| m_fused_fwd = t_fused_fwd.timeit(100) | ||
| print(f"Fused: {m_fused_fwd}") | ||
|
|
||
| # Benchmark Backward | ||
| print("\n" + "="*80) | ||
| print("Backward Pass") | ||
| print("="*80) | ||
|
|
||
| # Pre-compute forward for backward benchmark | ||
| y_sep = separate_conv_l2(x, conv_separate, head_dim) | ||
| grad_sep = torch.randn_like(y_sep) | ||
|
|
||
| def backward_sep(): | ||
| for xi in [x]: | ||
| if isinstance(xi, torch.Tensor): | ||
| xi.grad = None | ||
| y_sep.backward(grad_sep, retain_graph=True) | ||
|
|
||
| t_sep_bwd = benchmark.Timer( | ||
| stmt="backward_sep()", | ||
| globals={"backward_sep": backward_sep}, | ||
| ) | ||
| m_sep_bwd = t_sep_bwd.timeit(100) | ||
| print(f"Separate: {m_sep_bwd}") | ||
|
|
||
| y_fused = fused_conv_l2(x, conv_fused, head_dim) | ||
| grad_fused = torch.randn_like(y_fused) | ||
|
|
||
| def backward_fused(): | ||
| for xi in [x]: | ||
| if isinstance(xi, torch.Tensor): | ||
| xi.grad = None | ||
| y_fused.backward(grad_fused, retain_graph=True) | ||
|
|
||
| t_fused_bwd = benchmark.Timer( | ||
| stmt="backward_fused()", | ||
| globals={"backward_fused": backward_fused}, | ||
| ) | ||
| m_fused_bwd = t_fused_bwd.timeit(100) | ||
| print(f"Fused: {m_fused_bwd}") | ||
|
|
||
| # Benchmark Combined | ||
| print("\n" + "="*80) | ||
| print("Forward + Backward Pass") | ||
| print("="*80) | ||
|
|
||
| def combined_sep(): | ||
| for xi in [x]: | ||
| if isinstance(xi, torch.Tensor): | ||
| xi.grad = None | ||
| y = separate_conv_l2(x, conv_separate, head_dim) | ||
| y.backward(grad_sep, retain_graph=True) | ||
|
|
||
| t_sep_combined = benchmark.Timer( | ||
| stmt="combined_sep()", | ||
| globals={"combined_sep": combined_sep}, | ||
| ) | ||
| m_sep_combined = t_sep_combined.timeit(100) | ||
| print(f"Separate: {m_sep_combined}") | ||
|
|
||
| def combined_fused(): | ||
| for xi in [x]: | ||
| if isinstance(xi, torch.Tensor): | ||
| xi.grad = None | ||
| y = fused_conv_l2(x, conv_fused, head_dim) | ||
| y.backward(grad_fused, retain_graph=True) | ||
|
|
||
| t_fused_combined = benchmark.Timer( | ||
| stmt="combined_fused()", | ||
| globals={"combined_fused": combined_fused}, | ||
| ) | ||
| m_fused_combined = t_fused_combined.timeit(100) | ||
| print(f"Fused: {m_fused_combined}") | ||
|
|
||
| # Summary | ||
| time_sep_fwd = m_sep_fwd.median * 1000 | ||
| time_sep_bwd = m_sep_bwd.median * 1000 | ||
| time_sep_combined = m_sep_combined.median * 1000 | ||
|
|
||
| time_fused_fwd = m_fused_fwd.median * 1000 | ||
| time_fused_bwd = m_fused_bwd.median * 1000 | ||
| time_fused_combined = m_fused_combined.median * 1000 | ||
|
|
||
| print(f"\n{'='*80}") | ||
| print(f"{'Method':<35} {'Forward':<12} {'Backward':<12} {'Combined':<12} {'Speedup':<10}") | ||
| print("-"*80) | ||
| print(f"{'Separate (FLA)':<35} {time_sep_fwd:>10.3f}ms {time_sep_bwd:>10.3f}ms {time_sep_combined:>10.3f}ms {'1.00x':<10}") | ||
| print(f"{'Fused (Recompute)':<35} {time_fused_fwd:>10.3f}ms {time_fused_bwd:>10.3f}ms {time_fused_combined:>10.3f}ms {time_sep_combined/time_fused_combined:<10.2f}x") | ||
|
|
||
| speedup_fwd = (time_sep_fwd / time_fused_fwd - 1) * 100 | ||
| speedup_bwd = (time_sep_bwd / time_fused_bwd - 1) * 100 | ||
| speedup_combined = (time_sep_combined / time_fused_combined - 1) * 100 | ||
|
|
||
| print(f"\n{'='*80}") | ||
| print(f"Forward Speedup: {speedup_fwd:>+8.2f}%") | ||
| print(f"Backward Speedup: {speedup_bwd:>+8.2f}%") | ||
| print(f"Combined Speedup: {speedup_combined:>+8.2f}%") | ||
| print(f"\nMemory Saved: {B*T*D*2/1024/1024:.2f} MB per Conv layer (Y_act not stored)") | ||
| print(f"{'='*80}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Avoid
retain_graph=Truein the combined forward+backward benchmarksIn
combined_sep/combined_fusedyou rebuild the graph on every call:Because each benchmark iteration creates a new graph,
retain_graph=Truekeeps all of them alive, causing linear memory growth and potentially OOM for these tensor sizes, as well as skewing timings.You only need
retain_graph=Truein the backward‑only benchmarks where the same forward graph is reused; here you can safely drop it:I'll verify the review comment by examining the actual code structure and backward/combined benchmark implementations.
🏁 Script executed:
Length of output: 1801
🏁 Script executed:
Length of output: 1209
🏁 Script executed:
Length of output: 49
🏁 Script executed:
Length of output: 1071
🏁 Script executed:
Length of output: 1088
Remove
retain_graph=Truefrom the combined forward+backward benchmarksThe code inspection confirms the review's analysis. The backward-only benchmarks (lines 89–112) correctly use
retain_graph=Truebecause they precompute the forward pass once and then reuse the same graph for 100 backward iterations. However, the combined benchmarks (lines 127–131 and 141–145) rebuild the forward graph on each timer iteration, makingretain_graph=Trueunnecessary and harmful—all 100 graphs remain in memory simultaneously, causing linear memory growth and timing distortion.Remove
retain_graph=Truefromcombined_sep()andcombined_fused()at lines 131 and 145:def combined_sep(): for xi in [x]: if isinstance(xi, torch.Tensor): xi.grad = None y = separate_conv_l2(x, conv_separate, head_dim) - y.backward(grad_sep, retain_graph=True) + y.backward(grad_sep) def combined_fused(): for xi in [x]: if isinstance(xi, torch.Tensor): xi.grad = None y = fused_conv_l2(x, conv_fused, head_dim) - y.backward(grad_fused, retain_graph=True) + y.backward(grad_fused)🤖 Prompt for AI Agents