From 6b6926d40c99b89eecf8a99869b2dc3da2880279 Mon Sep 17 00:00:00 2001 From: Songlin Yang Date: Mon, 24 Nov 2025 10:02:30 +0000 Subject: [PATCH 1/3] Add fused short convolution kernel with L2 norm --- fla/ops/convolution/fused_short_conv.py | 413 ++++++++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 fla/ops/convolution/fused_short_conv.py diff --git a/fla/ops/convolution/fused_short_conv.py b/fla/ops/convolution/fused_short_conv.py new file mode 100644 index 000000000..8ba273edf --- /dev/null +++ b/fla/ops/convolution/fused_short_conv.py @@ -0,0 +1,413 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.utils import prepare_chunk_indices +from fla.utils import get_multiprocessor_count, input_guard + +@triton.heuristics({ + 'HAS_WEIGHT': lambda args: args['weight'] is not None, + 'HAS_BIAS': lambda args: args['bias'] is not None, + 'HAS_RESIDUAL': lambda args: args['residual'] is not None, + 'USE_INITIAL_STATE': lambda args: args['initial_state'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit +def fused_short_conv_fwd_kernel( + x, + y, + weight, + bias, + residual, + cu_seqlens, + initial_state, + chunk_indices, + B, + T, + D: tl.constexpr, + W: tl.constexpr, + BT: tl.constexpr, + BW: tl.constexpr, + BD: tl.constexpr, + EPS: tl.constexpr, + ACTIVATION: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_NORM: tl.constexpr, +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + i_n = i_b + bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64) + + o_d = i_d * BD + tl.arange(0, BD) + o_w = tl.arange(0, BW) + W - BW + m_d = o_d < D + m_w = o_w >= 0 + + if HAS_WEIGHT: + # [BD, BW] + b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0).to(tl.float32) + + b_y = tl.zeros((BT, BD), dtype=tl.float32) + if not USE_INITIAL_STATE: + for i_w in tl.static_range(-W + 1, 1): + p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0)) + # [BT, BD] + b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32) + if HAS_WEIGHT: + b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1) + b_y += b_yi + elif i_t * BT >= W: + # to make Triton compiler happy, we need to copy codes + for i_w in tl.static_range(-W + 1, 1): + p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0)) + # [BT, BD] + b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32) + if HAS_WEIGHT: + b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1) + b_y += b_yi + else: + o_t = i_t * BT + tl.arange(0, BT) + for i_w in tl.static_range(-W + 1, 1): + o_x = o_t + i_w + m_x = ((o_x >= 0) & (o_x < T))[:, None] & m_d + m_c = ((o_x + W >= 0) & (o_x < 0))[:, None] & m_d + + b_yi = tl.load(x + bos * D + o_x[:, None] * D + o_d, mask=m_x, other=0).to(tl.float32) + + b_yi += tl.load(initial_state + i_n * D*W + o_d * W + (o_x + W)[:, None], mask=m_c, other=0).to(tl.float32) + + if HAS_WEIGHT: + b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1) + b_y += b_yi + + if HAS_BIAS: + b_y += tl.load(bias + o_d, mask=m_d).to(tl.float32) + + if ACTIVATION == 'swish' or ACTIVATION == 'silu': + b_y = b_y * tl.sigmoid(b_y) + + if USE_NORM: + # L2 norm over the head dimension (BD) + # b_y is [BT, BD] + b_var = tl.sum(b_y * b_y, axis=1) + b_std = tl.sqrt(b_var + EPS) + b_y = b_y / b_std[:, None] + + if HAS_RESIDUAL: + p_residual = tl.make_block_ptr(residual + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + b_residual = tl.load(p_residual, boundary_check=(0, 1)) + b_y += b_residual + + p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + tl.store(p_y, tl.cast(b_y, dtype=p_y.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'HAS_WEIGHT': lambda args: args['dw'] is not None, + 'HAS_BIAS': lambda args: args['db'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit +def fused_short_conv_bwd_kernel( + x, + y, + weight, + bias, + initial_state, + dh0, + dht, + dy, + dx, + dw, + db, + cu_seqlens, + chunk_indices, + B, + T, + D: tl.constexpr, + W: tl.constexpr, + BT: tl.constexpr, + BW: tl.constexpr, + BD: tl.constexpr, + EPS: tl.constexpr, + ACTIVATION: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_NORM: tl.constexpr, +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + i_tg = i_b * tl.num_programs(1) + i_t + i_n = i_b + bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64) + + o_d = i_d * BD + tl.arange(0, BD) + o_w = tl.arange(0, BW) + W - BW + m_d = o_d < D + m_w = o_w >= 0 + + if HAS_WEIGHT: + b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0).to(tl.float32) + + if HAS_WEIGHT: + p_x = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + + b_dx = tl.zeros((BT, BD), dtype=tl.float32) + if HAS_BIAS: + b_db = tl.zeros((BD,), dtype=tl.float32) + + for i_w in range(0, W): + p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0)) + b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32) + + if USE_NORM: + # Recompute y_conv at T_global = i_t*BT + i_w + t_local + # We need to loop over k (kernel support) to compute convolution + b_y_conv = tl.zeros((BT, BD), dtype=tl.float32) + t_local = tl.arange(0, BT) + + for k in range(0, W): + w_k = tl.sum(b_w * (o_w[None, :] == k), 1) + # Forward: y[t] = sum_{j=0}^{W-1} x[t - W + 1 + j] * w[j] + # Here t = i_t * BT + i_w + t_local, j = k + # So x index = t - W + 1 + k = (i_t * BT + i_w + t_local) - W + 1 + k + x_offset = i_t * BT + i_w - W + 1 + k + m_x_valid = (x_offset + t_local >= 0) & (x_offset + t_local < T) + + # We need to reload x from memory as it's not in registers. + # Constructing pointers manually to allow random access in loop + # This is efficient enough for small W. + val_x = tl.load(x + bos * D + (x_offset + t_local)[:, None] * D + o_d[None, :], + mask=m_x_valid[:, None] & m_d[None, :], other=0.0).to(tl.float32) + b_y_conv += val_x * w_k[None, :] + + if HAS_BIAS: + b_y_conv += tl.load(bias + o_d, mask=m_d).to(tl.float32) + + b_y_act = b_y_conv + if ACTIVATION == 'swish' or ACTIVATION == 'silu': + b_y_act = b_y_conv * tl.sigmoid(b_y_conv) + + b_var = tl.sum(b_y_act * b_y_act, 1) + b_std = tl.sqrt(b_var + EPS) + b_inv_std = 1.0 / b_std + b_y_out = b_y_act * b_inv_std[:, None] + b_dot = tl.sum(b_dy * b_y_out, 1) + b_dy = (b_dy - b_y_out * b_dot[:, None]) * b_inv_std[:, None] + + # For activation backward + if ACTIVATION == 'swish' or ACTIVATION == 'silu': + b_sig = tl.sigmoid(b_y_conv) + b_dy = b_dy * b_sig * (1 + b_y_conv * (1 - b_sig)) + + b_wdy = b_dy + if HAS_WEIGHT: + b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1) + b_dw = tl.sum(b_dy * b_x, 0) + tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d) + + if HAS_BIAS and i_w == 0: + b_db += tl.sum(b_dy, 0) + + b_dx += b_wdy + + p_dx = tl.make_block_ptr(dx + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + tl.store(p_dx, tl.cast(b_dx, dtype=p_dx.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1)) + + if HAS_BIAS: + tl.store(db + i_tg * D + o_d, b_db.to(db.dtype.element_ty), mask=m_d) + + +class FusedShortConvFunction(torch.autograd.Function): + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + activation: str | None = None, + cu_seqlens: torch.Tensor | None = None, + chunk_indices: torch.LongTensor | None = None, + use_norm: bool = False, + norm_eps: float = 1e-5, + head_dim: int | None = None, + ): + ctx.activation = activation + ctx.cu_seqlens = cu_seqlens + ctx.chunk_indices = chunk_indices + ctx.use_norm = use_norm + ctx.norm_eps = norm_eps + ctx.head_dim = head_dim + + # Save tensors for backward + # We use recomputation strategy: don't save y_act, recompute in backward + ctx.save_for_backward(x, weight, bias, residual, initial_state) + + shape = x.shape + if x.shape[-1] != weight.shape[0]: + x = rearrange(x, 'b t ... -> b t (...)') + B, T, D, W = *x.shape, weight.shape[1] + BT = min(64, triton.next_power_of_2(triton.cdiv(max(16, B*T), get_multiprocessor_count(x.device.index)))) + BW = triton.next_power_of_2(W) + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + # Determine BD + if use_norm: + assert head_dim is not None, "head_dim must be provided when use_norm is True" + BD = head_dim + # Check BD is power of 2? Triton prefers it. + # If not, next power of 2 and mask handles it. + BD = triton.next_power_of_2(head_dim) + else: + BD = 32 # Default fallback or simple value since we don't autotune + + y = torch.empty_like(x) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B) + + fused_short_conv_fwd_kernel[grid]( + x=x, + y=y, + weight=weight, + bias=bias, + residual=residual, + cu_seqlens=cu_seqlens, + initial_state=initial_state, + chunk_indices=chunk_indices, + B=B, + T=T, + D=D, + W=W, + BT=BT, + BW=BW, + BD=BD, + EPS=norm_eps, + ACTIVATION=activation, + USE_NORM=use_norm, + ) + return y, None # final_state not implemented for now + + @staticmethod + @input_guard + def backward(ctx, dy: torch.Tensor, dht: torch.Tensor | None = None): + x, weight, bias, residual, initial_state = ctx.saved_tensors + use_norm = ctx.use_norm + norm_eps = ctx.norm_eps + head_dim = ctx.head_dim + activation = ctx.activation + + # Similar setup + shape = x.shape + if x.shape[-1] != weight.shape[0]: + x = rearrange(x, 'b t ... -> b t (...)') + B, T, D = x.shape + W = weight.shape[1] + BT = min(64, triton.next_power_of_2(triton.cdiv(max(16, B*T), get_multiprocessor_count(x.device.index)))) + BW = triton.next_power_of_2(W) + if ctx.chunk_indices is None and ctx.cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(ctx.cu_seqlens, BT) + else: + chunk_indices = ctx.chunk_indices + NT = len(chunk_indices) if ctx.cu_seqlens is not None else triton.cdiv(T, BT) + if use_norm: + BD = triton.next_power_of_2(head_dim) + else: + BD = 32 + + dx = torch.empty_like(x) + dh0 = None # Not implemented + dr = dy if residual is not None else None + + # Always use recomputation strategy (best performance + memory efficiency) + y = None + + # Standard backward kernel + dw = weight.new_empty(B*NT, *weight.shape, dtype=torch.float) if weight is not None else None + db = bias.new_empty(B*NT, *bias.shape, dtype=torch.float) if bias is not None else None + + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B) + fused_short_conv_bwd_kernel[grid]( + x=x, + y=y, + weight=weight, + bias=bias, + initial_state=initial_state, + dh0=dh0, + dht=dht, + dy=dy, + dx=dx, + dw=dw, + db=db, + cu_seqlens=ctx.cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + D=D, + W=W, + BT=BT, + BW=BW, + BD=BD, + EPS=norm_eps, + ACTIVATION=activation, + USE_NORM=use_norm, + ) + + if weight is not None: + dw = dw.sum(0).to(weight) + if bias is not None: + db = db.sum(0).to(bias) + + return dx, dw, db, dr, dh0, None, None, None, None, None, None, None + +def fused_short_conv( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + activation: str | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, + use_norm: bool = False, + norm_eps: float = 1e-5, + head_dim: int | None = None, +): + """ + Fused short convolution with optional L2 normalization. + + Uses recomputation strategy in backward: activations are recomputed on-the-fly + instead of being saved, providing both speed and memory benefits. + """ + return FusedShortConvFunction.apply( + x, weight, bias, residual, initial_state, output_final_state, activation, cu_seqlens, chunk_indices, use_norm, norm_eps, head_dim + ) From 562a840455f41989c49cd5f8304093a715b7412c Mon Sep 17 00:00:00 2001 From: Songlin Yang Date: Mon, 24 Nov 2025 10:07:35 +0000 Subject: [PATCH 2/3] add test and benchmark --- benchmarks/modules/benchmark_fused_conv_l2.py | 178 ++++++++++++++++++ fla/layers/delta_net.py | 12 +- fla/modules/convolution.py | 49 ++++- fla/ops/convolution/__init__.py | 0 4 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 benchmarks/modules/benchmark_fused_conv_l2.py create mode 100644 fla/ops/convolution/__init__.py diff --git a/benchmarks/modules/benchmark_fused_conv_l2.py b/benchmarks/modules/benchmark_fused_conv_l2.py new file mode 100644 index 000000000..ab45d9ba1 --- /dev/null +++ b/benchmarks/modules/benchmark_fused_conv_l2.py @@ -0,0 +1,178 @@ +import torch +from einops import rearrange + +from fla.modules.convolution import ShortConvolution +from fla.modules.l2norm import l2norm +from fla.utils import device + +def separate_conv_l2(x, conv, head_dim): + """Separate Conv + L2 Norm""" + y, _ = conv(x) + y = rearrange(y, 'b t (h d) -> b t h d', d=head_dim) + y = l2norm(y, eps=1e-5) + y = rearrange(y, 'b t h d -> b t (h d)') + return y + +def fused_conv_l2(x, conv_fused, head_dim): + """Fused Conv + L2 Norm""" + y, _ = conv_fused(x, head_dim=head_dim) + return y + +if __name__ == "__main__": + import torch.utils.benchmark as benchmark + + # Test configurations + B, T, D, W = 4, 2048, 2048, 4 + H = 16 + head_dim = D // H + + print("="*80) + print(f"Benchmarking Conv + L2 Norm: B={B}, T={T}, D={D}, W={W}, H={H}, head_dim={head_dim}") + print("="*80) + + dtype = torch.bfloat16 + + # Create input + x = torch.randn(B, T, D, device=device, dtype=dtype, requires_grad=True) + + # Separate Conv (no norm) + conv_separate = ShortConvolution( + hidden_size=D, + kernel_size=W, + bias=False, + activation='silu', + norm=None, + device=device, + dtype=dtype, + ) + + # Fused Conv + L2 Norm + conv_fused = ShortConvolution( + hidden_size=D, + kernel_size=W, + bias=False, + activation='silu', + norm='l2', + norm_eps=1e-5, + device=device, + dtype=dtype, + ) + + # Copy weights + conv_fused.weight.data.copy_(conv_separate.weight.data) + + # Benchmark Forward + print("\n" + "="*80) + print("Forward Pass") + print("="*80) + + t_sep_fwd = benchmark.Timer( + stmt="separate_conv_l2(x, conv, head_dim)", + globals={"separate_conv_l2": separate_conv_l2, "x": x, "conv": conv_separate, "head_dim": head_dim}, + ) + m_sep_fwd = t_sep_fwd.timeit(100) + print(f"Separate: {m_sep_fwd}") + + t_fused_fwd = benchmark.Timer( + stmt="fused_conv_l2(x, conv, head_dim)", + globals={"fused_conv_l2": fused_conv_l2, "x": x, "conv": conv_fused, "head_dim": head_dim}, + ) + m_fused_fwd = t_fused_fwd.timeit(100) + print(f"Fused: {m_fused_fwd}") + + # Benchmark Backward + print("\n" + "="*80) + print("Backward Pass") + print("="*80) + + # Pre-compute forward for backward benchmark + y_sep = separate_conv_l2(x, conv_separate, head_dim) + grad_sep = torch.randn_like(y_sep) + + def backward_sep(): + for xi in [x]: + if isinstance(xi, torch.Tensor): + xi.grad = None + y_sep.backward(grad_sep, retain_graph=True) + + t_sep_bwd = benchmark.Timer( + stmt="backward_sep()", + globals={"backward_sep": backward_sep}, + ) + m_sep_bwd = t_sep_bwd.timeit(100) + print(f"Separate: {m_sep_bwd}") + + y_fused = fused_conv_l2(x, conv_fused, head_dim) + grad_fused = torch.randn_like(y_fused) + + def backward_fused(): + for xi in [x]: + if isinstance(xi, torch.Tensor): + xi.grad = None + y_fused.backward(grad_fused, retain_graph=True) + + t_fused_bwd = benchmark.Timer( + stmt="backward_fused()", + globals={"backward_fused": backward_fused}, + ) + m_fused_bwd = t_fused_bwd.timeit(100) + print(f"Fused: {m_fused_bwd}") + + # Benchmark Combined + print("\n" + "="*80) + print("Forward + Backward Pass") + print("="*80) + + def combined_sep(): + for xi in [x]: + if isinstance(xi, torch.Tensor): + xi.grad = None + y = separate_conv_l2(x, conv_separate, head_dim) + y.backward(grad_sep, retain_graph=True) + + t_sep_combined = benchmark.Timer( + stmt="combined_sep()", + globals={"combined_sep": combined_sep}, + ) + m_sep_combined = t_sep_combined.timeit(100) + print(f"Separate: {m_sep_combined}") + + def combined_fused(): + for xi in [x]: + if isinstance(xi, torch.Tensor): + xi.grad = None + y = fused_conv_l2(x, conv_fused, head_dim) + y.backward(grad_fused, retain_graph=True) + + t_fused_combined = benchmark.Timer( + stmt="combined_fused()", + globals={"combined_fused": combined_fused}, + ) + m_fused_combined = t_fused_combined.timeit(100) + print(f"Fused: {m_fused_combined}") + + # Summary + time_sep_fwd = m_sep_fwd.median * 1000 + time_sep_bwd = m_sep_bwd.median * 1000 + time_sep_combined = m_sep_combined.median * 1000 + + time_fused_fwd = m_fused_fwd.median * 1000 + time_fused_bwd = m_fused_bwd.median * 1000 + time_fused_combined = m_fused_combined.median * 1000 + + print(f"\n{'='*80}") + print(f"{'Method':<35} {'Forward':<12} {'Backward':<12} {'Combined':<12} {'Speedup':<10}") + print("-"*80) + print(f"{'Separate (FLA)':<35} {time_sep_fwd:>10.3f}ms {time_sep_bwd:>10.3f}ms {time_sep_combined:>10.3f}ms {'1.00x':<10}") + print(f"{'Fused (Recompute)':<35} {time_fused_fwd:>10.3f}ms {time_fused_bwd:>10.3f}ms {time_fused_combined:>10.3f}ms {time_sep_combined/time_fused_combined:<10.2f}x") + + speedup_fwd = (time_sep_fwd / time_fused_fwd - 1) * 100 + speedup_bwd = (time_sep_bwd / time_fused_bwd - 1) * 100 + speedup_combined = (time_sep_combined / time_fused_combined - 1) * 100 + + print(f"\n{'='*80}") + print(f"Forward Speedup: {speedup_fwd:>+8.2f}%") + print(f"Backward Speedup: {speedup_bwd:>+8.2f}%") + print(f"Combined Speedup: {speedup_combined:>+8.2f}%") + print(f"\nMemory Saved: {B*T*D*2/1024/1024:.2f} MB per Conv layer (Y_act not stored)") + print(f"{'='*80}") diff --git a/fla/layers/delta_net.py b/fla/layers/delta_net.py index 993dd74d3..bd6ea58ca 100644 --- a/fla/layers/delta_net.py +++ b/fla/layers/delta_net.py @@ -87,6 +87,7 @@ def __init__( qk_activation: str = 'silu', qk_norm: str = 'l2', norm_eps: float = 1e-5, + fuse_norm: bool = True, **kwargs, ) -> DeltaNet: super().__init__() @@ -94,6 +95,7 @@ def __init__( self.mode = mode self.qk_activation = qk_activation self.qk_norm = qk_norm + self.fuse_norm = fuse_norm and (qk_norm == 'l2') assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] assert self.qk_norm in ['l2', 'sum'] @@ -136,12 +138,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu' if qk_activation == 'silu' else None, + norm='l2' if self.fuse_norm else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=conv_bias, activation='silu' if qk_activation == 'silu' else None, + norm='l2' if self.fuse_norm else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, @@ -200,12 +206,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_norm else None ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_norm else None ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -252,7 +260,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=(self.qk_norm == 'l2'), + use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm), ) elif mode == 'chunk': o, recurrent_state = chunk_delta_rule( @@ -263,7 +271,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=(self.qk_norm == 'l2'), + use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm), ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") diff --git a/fla/modules/convolution.py b/fla/modules/convolution.py index 301ac9040..57916347b 100644 --- a/fla/modules/convolution.py +++ b/fla/modules/convolution.py @@ -10,6 +10,7 @@ import triton.language as tl from einops import rearrange +from fla.ops.convolution.fused_short_conv import fused_short_conv from fla.ops.utils import prepare_chunk_indices, prepare_sequence_ids from fla.utils import IS_AMD, autotune_cache_kwargs, get_multiprocessor_count, input_guard @@ -836,6 +837,8 @@ def __init__( kernel_size: int, bias: bool = False, activation: str | None = 'silu', + norm: str | None = None, + norm_eps: float = 1e-5, backend: str | None = 'triton', device: torch.device | None = None, dtype: torch.dtype | None = None, @@ -854,11 +857,17 @@ def __init__( self.hidden_size = hidden_size self.activation = None + self.norm = norm + self.norm_eps = norm_eps if activation is not None: assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet." self.activation = activation + if norm is not None: + assert norm == 'l2', f"Normalization `{norm}` not supported yet." + assert backend == 'triton', "Fused normalization only supported with Triton backend." + if 'use_fast_conv1d' in kwargs: warnings.warn( "The `use_fast_conv1d` parameter is deprecated and will be ignored. " @@ -906,6 +915,7 @@ def forward( output_final_state: bool = False, cu_seqlens: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, + head_dim: int | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -926,6 +936,8 @@ def forward( Shape: [B+1] chunk_indices (Optional[torch.LongTensor]): Chunk indices for variable-length sequences. Default: `None`. + head_dim (Optional[int]): + The head dimension for L2 normalization. Default: `None`. Returns: Tensor of shape `[B, T, D]`. @@ -946,6 +958,7 @@ def forward( cache=cache, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + head_dim=head_dim, ) return y, cache @@ -964,6 +977,24 @@ def forward( ) self.backend = 'triton' + if self.norm is not None: + if head_dim is None: + raise ValueError("`head_dim` must be provided when using fused normalization.") + return fused_short_conv( + x=x, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + residual=residual, + initial_state=cache, + output_final_state=output_final_state, + activation=self.activation, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_norm=True, + norm_eps=self.norm_eps, + head_dim=head_dim, + ) + return causal_conv1d( x=x, weight=rearrange(self.weight, "d 1 w -> d w"), @@ -985,6 +1016,7 @@ def step( cache: torch.Tensor, output_final_state: bool = False, cu_seqlens: torch.LongTensor | None = None, + head_dim: int | None = None, ): B, _, D, W = *x.shape, self.kernel_size[0] N = B if cu_seqlens is None else len(cu_seqlens) - 1 @@ -992,7 +1024,7 @@ def step( cache = x.new_zeros(N, D, W) # NOTE: we follow the fast mode that updates the cache in-place if self.backend == 'triton': - return causal_conv1d_update( + y, cache = causal_conv1d_update( x=x, cache=cache, residual=residual, @@ -1000,6 +1032,14 @@ def step( bias=self.bias, activation=self.activation, ) + if self.norm is not None: + if head_dim is None: + raise ValueError("`head_dim` must be provided when using fused normalization.") + y = rearrange(y, '... (h d) -> ... h d', d=head_dim) + norm = y.norm(p=2, dim=-1, keepdim=True) + y = y / (norm + self.norm_eps) + y = rearrange(y, '... h d -> ... (h d)') + return y, cache shape = x.shape x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1) @@ -1017,6 +1057,13 @@ def step( y = y.view(shape) if residual is not None: y.add_(residual) + if self.norm is not None: + if head_dim is None: + raise ValueError("`head_dim` must be provided when using fused normalization.") + y = rearrange(y, '... (h d) -> ... h d', d=head_dim) + norm = y.norm(p=2, dim=-1, keepdim=True) + y = y / (norm + self.norm_eps) + y = rearrange(y, '... h d -> ... (h d)') return y, cache @property diff --git a/fla/ops/convolution/__init__.py b/fla/ops/convolution/__init__.py new file mode 100644 index 000000000..e69de29bb From f1d9db7ff13c8e88dbc464c11da6584b4a3ccd5c Mon Sep 17 00:00:00 2001 From: Songlin Yang Date: Mon, 24 Nov 2025 10:39:31 +0000 Subject: [PATCH 3/3] Add fuse_conv_l2 flag to conv+l2 consumers --- fla/layers/comba.py | 12 ++++++++-- fla/layers/delta_net.py | 24 ++++++++++++------- fla/layers/gated_deltanet.py | 12 ++++++++-- fla/layers/gated_deltaproduct.py | 12 ++++++++-- fla/layers/kda.py | 12 ++++++++-- fla/layers/mesa_net.py | 15 +++++++++--- fla/layers/mom.py | 18 ++++++++++---- fla/models/comba/configuration_comba.py | 2 ++ fla/models/comba/modeling_comba.py | 1 + .../delta_net/configuration_delta_net.py | 2 ++ fla/models/delta_net/modeling_delta_net.py | 1 + .../configuration_gated_deltanet.py | 2 ++ .../gated_deltanet/modeling_gated_deltanet.py | 1 + .../configuration_gated_deltaproduct.py | 2 ++ .../modeling_gated_deltaproduct.py | 1 + fla/models/kda/configuration_kda.py | 2 ++ fla/models/kda/modeling_kda.py | 1 + fla/models/mesa_net/configuration_mesa_net.py | 2 ++ fla/models/mesa_net/modeling_mesa_net.py | 1 + fla/models/mom/configuration_mom.py | 2 ++ fla/models/mom/modeling_mom.py | 1 + fla/ops/convolution/__init__.py | 4 ++++ 22 files changed, 107 insertions(+), 23 deletions(-) diff --git a/fla/layers/comba.py b/fla/layers/comba.py index 0ccbbf23c..4edcc756b 100644 --- a/fla/layers/comba.py +++ b/fla/layers/comba.py @@ -91,6 +91,7 @@ def __init__( conv_bias: bool = False, layer_idx: int = None, norm_eps: float = 1e-5, + fuse_conv_l2: bool = True, **kwargs, ) -> Comba: super().__init__() @@ -106,6 +107,7 @@ def __init__( self.use_inner_decay = use_inner_decay self.conv_size = conv_size self.conv_bias = conv_bias + self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv self.head_dim = head_dim self.num_heads = num_heads @@ -179,12 +181,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, @@ -243,12 +249,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -291,7 +299,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) elif mode == 'fused_recurrent': o, recurrent_state = fused_recurrent_comba( @@ -304,7 +312,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") diff --git a/fla/layers/delta_net.py b/fla/layers/delta_net.py index bd6ea58ca..2eeece232 100644 --- a/fla/layers/delta_net.py +++ b/fla/layers/delta_net.py @@ -87,7 +87,8 @@ def __init__( qk_activation: str = 'silu', qk_norm: str = 'l2', norm_eps: float = 1e-5, - fuse_norm: bool = True, + fuse_conv_l2: bool = True, + fuse_norm: bool | None = None, **kwargs, ) -> DeltaNet: super().__init__() @@ -95,7 +96,14 @@ def __init__( self.mode = mode self.qk_activation = qk_activation self.qk_norm = qk_norm - self.fuse_norm = fuse_norm and (qk_norm == 'l2') + if fuse_norm is not None: + warnings.warn( + "`fuse_norm` is deprecated for DeltaNet; use `fuse_conv_l2` to control the fused " + "ShortConvolution + L2 kernel.", + stacklevel=2, + ) + fuse_conv_l2 = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv and (qk_norm == 'l2') assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] assert self.qk_norm in ['l2', 'sum'] @@ -138,7 +146,7 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu' if qk_activation == 'silu' else None, - norm='l2' if self.fuse_norm else None, + norm='l2' if self.fuse_conv_l2 else None, norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( @@ -146,7 +154,7 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu' if qk_activation == 'silu' else None, - norm='l2' if self.fuse_norm else None, + norm='l2' if self.fuse_conv_l2 else None, norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( @@ -206,14 +214,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, - head_dim=self.head_k_dim if self.fuse_norm else None + head_dim=self.head_k_dim if self.fuse_conv_l2 else None ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, - head_dim=self.head_k_dim if self.fuse_norm else None + head_dim=self.head_k_dim if self.fuse_conv_l2 else None ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -260,7 +268,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm), + use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2), ) elif mode == 'chunk': o, recurrent_state = chunk_delta_rule( @@ -271,7 +279,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm), + use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2), ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") diff --git a/fla/layers/gated_deltanet.py b/fla/layers/gated_deltanet.py index 65d4ee660..437ea44a1 100644 --- a/fla/layers/gated_deltanet.py +++ b/fla/layers/gated_deltanet.py @@ -100,6 +100,7 @@ def __init__( conv_bias: bool = False, layer_idx: int = None, norm_eps: float = 1e-5, + fuse_conv_l2: bool = True, **kwargs, ) -> GatedDeltaNet: super().__init__() @@ -113,6 +114,7 @@ def __init__( self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias + self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv self.head_dim = head_dim self.num_heads = num_heads @@ -174,12 +176,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, @@ -239,12 +245,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -280,7 +288,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) elif mode == 'fused_recurrent': o, recurrent_state = fused_recurrent_gated_delta_rule( @@ -292,7 +300,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") diff --git a/fla/layers/gated_deltaproduct.py b/fla/layers/gated_deltaproduct.py index 53b096aa2..1085ebc3f 100644 --- a/fla/layers/gated_deltaproduct.py +++ b/fla/layers/gated_deltaproduct.py @@ -44,6 +44,7 @@ def __init__( use_forget_gate: bool = True, allow_neg_eigval: bool = True, num_householder: int = 2, + fuse_conv_l2: bool = True, **kwargs, ) -> GatedDeltaProduct: super().__init__() @@ -60,6 +61,7 @@ def __init__( self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias + self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv self.head_dim = head_dim self.num_heads = num_heads @@ -122,12 +124,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim * num_householder, kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim * num_householder, @@ -196,12 +202,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -243,7 +251,7 @@ def forward( output_final_state=use_cache, cu_seqlens=cu_seqlens, num_householder=self.num_householder, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) elif mode == 'fused_recurrent': @@ -264,7 +272,7 @@ def forward( initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) o = rearrange(o, '... (t n) h d -> ... t n h d', n=self.num_householder)[..., -1, :, :].contiguous() diff --git a/fla/layers/kda.py b/fla/layers/kda.py index 3533a1a86..ced65b32b 100644 --- a/fla/layers/kda.py +++ b/fla/layers/kda.py @@ -71,6 +71,7 @@ def __init__( conv_bias: bool = False, layer_idx: int = None, norm_eps: float = 1e-5, + fuse_conv_l2: bool = True, **kwargs, ) -> KimiDeltaAttention: super().__init__() @@ -83,6 +84,7 @@ def __init__( self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias + self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv self.head_dim = head_dim self.num_heads = num_heads @@ -122,12 +124,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, @@ -194,12 +200,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) v, conv_state_v = self.v_conv1d( x=self.v_proj(hidden_states), @@ -237,7 +245,7 @@ def forward( beta=beta, initial_state=recurrent_state, output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, cu_seqlens=cu_seqlens, ) elif mode == 'fused_recurrent': @@ -249,7 +257,7 @@ def forward( beta=beta, initial_state=recurrent_state, output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, cu_seqlens=cu_seqlens, ) else: diff --git a/fla/layers/mesa_net.py b/fla/layers/mesa_net.py index 203e65856..d4cd0007f 100644 --- a/fla/layers/mesa_net.py +++ b/fla/layers/mesa_net.py @@ -66,6 +66,7 @@ def __init__( lambda_lower_bound: float = 0.25, max_cg_step_training: int = 30, max_cg_step_decoding: int = 30, + fuse_conv_l2: bool = True, **kwargs, ) -> MesaNet: super().__init__() @@ -86,6 +87,7 @@ def __init__( self.lambda_lower_bound = lambda_lower_bound self.max_cg_step_training = max_cg_step_training self.max_cg_step_decoding = max_cg_step_decoding + self.fuse_conv_l2 = fuse_conv_l2 self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) @@ -106,12 +108,16 @@ def __init__( kernel_size=conv_size, bias=self.conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=self.conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) if use_output_gate: self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) @@ -154,12 +160,14 @@ def forward( cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) k, conv_state_k = self.k_conv1d( x=self.k_proj(hidden_states), cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_k_dim if self.fuse_conv_l2 else None, ) v = self.v_proj(hidden_states) @@ -184,13 +192,14 @@ def forward( lamb=lamb, output_final_state=use_cache, max_CG_iteration=self.max_cg_step_training, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, cu_seqlens=cu_seqlens, ) # decoding else: - q = l2_norm(q) - k = l2_norm(k) + if not self.fuse_conv_l2: + q = l2_norm(q) + k = l2_norm(k) o, h_kk, h_kv = mesa_net_decoding_one_step( q=q.squeeze(0), k=k.squeeze(0), diff --git a/fla/layers/mom.py b/fla/layers/mom.py index d7042036d..6f6f2f635 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -289,6 +289,7 @@ def __init__( conv_bias: bool = False, layer_idx: int = None, norm_eps: float = 1e-5, + fuse_conv_l2: bool = True, num_memories: int = 8, topk: int = 2, capacity: float = 1.0, @@ -312,6 +313,7 @@ def __init__( self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias + self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv self.head_dim = head_dim self.num_heads = num_heads @@ -381,12 +383,16 @@ def __init__( kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, bias=conv_bias, activation='silu', + norm='l2' if self.fuse_conv_l2 else None, + norm_eps=norm_eps, ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, @@ -513,6 +519,7 @@ def forward( cache=conv_q, output_final_state=use_cache, cu_seqlens=conv_cu_seqlens, + head_dim=self.head_qk_dim if self.fuse_conv_l2 else None, ) conv_state_q[0] = self.handle_recurrent_state( conv_state_q[0], @@ -533,6 +540,7 @@ def forward( cache=conv_k, output_final_state=use_cache, cu_seqlens=conv_cu_seqlens, + head_dim=self.head_qk_dim if self.fuse_conv_l2 else None, ) conv_state_k[0] = self.handle_recurrent_state( conv_state_k[0], @@ -580,7 +588,7 @@ def forward( beta=cu_beta, initial_state=recurrent_state[0], output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, cu_seqlens=cu_seqlens, ) recurrent_state[0] = self.handle_recurrent_state( @@ -605,7 +613,7 @@ def forward( beta=cu_beta, initial_state=memories, output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, cu_seqlens=cu_seqlens, ) recurrent_state[0] = self.handle_recurrent_state( @@ -684,12 +692,14 @@ def shared_o( cache=conv_state_q[1], output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_qk_dim if self.fuse_conv_l2 else None, ) k, conv_state_k[1] = self.k_conv1d( x=self.shared_k(hidden_states), cache=conv_state_k[1], output_final_state=use_cache, cu_seqlens=cu_seqlens, + head_dim=self.head_qk_dim if self.fuse_conv_l2 else None, ) v, conv_state_v[1] = self.v_conv1d( x=self.shared_v(hidden_states), @@ -716,7 +726,7 @@ def shared_o( initial_state=recurrent_state[-1], output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) elif mode == 'fused_recurrent': o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( @@ -728,7 +738,7 @@ def shared_o( initial_state=recurrent_state[-1], output_final_state=use_cache, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=not self.fuse_conv_l2, ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") diff --git a/fla/models/comba/configuration_comba.py b/fla/models/comba/configuration_comba.py index 6f17f65e0..1bf8c6658 100644 --- a/fla/models/comba/configuration_comba.py +++ b/fla/models/comba/configuration_comba.py @@ -36,6 +36,7 @@ def __init__( tie_word_embeddings: bool = False, initializer_range: float = 0.02, fuse_norm: bool = True, + fuse_conv_l2: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, fuse_linear_cross_entropy: bool = False, @@ -67,6 +68,7 @@ def __init__( self.initializer_range = initializer_range self.fuse_norm = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy self.fuse_linear_cross_entropy = fuse_linear_cross_entropy diff --git a/fla/models/comba/modeling_comba.py b/fla/models/comba/modeling_comba.py index 62443caa4..b9c0ae63a 100644 --- a/fla/models/comba/modeling_comba.py +++ b/fla/models/comba/modeling_comba.py @@ -65,6 +65,7 @@ def __init__(self, config: CombaConfig, layer_idx: int): conv_size=config.conv_size, norm_eps=config.norm_eps, layer_idx=layer_idx, + fuse_conv_l2=config.fuse_conv_l2, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = CombaMLP( diff --git a/fla/models/delta_net/configuration_delta_net.py b/fla/models/delta_net/configuration_delta_net.py index e432338b6..8918431b3 100644 --- a/fla/models/delta_net/configuration_delta_net.py +++ b/fla/models/delta_net/configuration_delta_net.py @@ -37,6 +37,7 @@ def __init__( tie_word_embeddings: bool = False, initializer_range: float = 0.02, fuse_norm: bool = True, + fuse_conv_l2: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, fuse_linear_cross_entropy: bool = False, @@ -67,6 +68,7 @@ def __init__( self.use_cache = use_cache self.initializer_range = initializer_range self.fuse_norm = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy self.fuse_linear_cross_entropy = fuse_linear_cross_entropy diff --git a/fla/models/delta_net/modeling_delta_net.py b/fla/models/delta_net/modeling_delta_net.py index d1290caf2..838f7f58d 100644 --- a/fla/models/delta_net/modeling_delta_net.py +++ b/fla/models/delta_net/modeling_delta_net.py @@ -66,6 +66,7 @@ def __init__(self, config: DeltaNetConfig, layer_idx: int): qk_norm=config.qk_norm, qk_activation=config.qk_activation, norm_eps=config.norm_eps, + fuse_conv_l2=config.fuse_conv_l2, layer_idx=layer_idx, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) diff --git a/fla/models/gated_deltanet/configuration_gated_deltanet.py b/fla/models/gated_deltanet/configuration_gated_deltanet.py index 653f80191..b751a1a77 100644 --- a/fla/models/gated_deltanet/configuration_gated_deltanet.py +++ b/fla/models/gated_deltanet/configuration_gated_deltanet.py @@ -34,6 +34,7 @@ def __init__( tie_word_embeddings: bool = False, initializer_range: float = 0.02, fuse_norm: bool = True, + fuse_conv_l2: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, fuse_linear_cross_entropy: bool = False, @@ -62,6 +63,7 @@ def __init__( self.initializer_range = initializer_range self.fuse_norm = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy self.fuse_linear_cross_entropy = fuse_linear_cross_entropy diff --git a/fla/models/gated_deltanet/modeling_gated_deltanet.py b/fla/models/gated_deltanet/modeling_gated_deltanet.py index 0d3d2634a..19e5de264 100644 --- a/fla/models/gated_deltanet/modeling_gated_deltanet.py +++ b/fla/models/gated_deltanet/modeling_gated_deltanet.py @@ -66,6 +66,7 @@ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int): conv_size=config.conv_size, norm_eps=config.norm_eps, layer_idx=layer_idx, + fuse_conv_l2=config.fuse_conv_l2, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = GatedDeltaNetMLP( diff --git a/fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py b/fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py index fe40abfb8..aca3d94c1 100644 --- a/fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py +++ b/fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py @@ -32,6 +32,7 @@ def __init__( tie_word_embeddings: bool = False, initializer_range: float = 0.02, fuse_norm: bool = True, + fuse_conv_l2: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, fuse_linear_cross_entropy: bool = False, @@ -62,6 +63,7 @@ def __init__( self.initializer_range = initializer_range self.fuse_norm = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy self.fuse_linear_cross_entropy = fuse_linear_cross_entropy diff --git a/fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py b/fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py index 2e66a35f4..966caf399 100644 --- a/fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +++ b/fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py @@ -67,6 +67,7 @@ def __init__(self, config: GatedDeltaProductConfig, layer_idx: int): allow_neg_eigval=config.allow_neg_eigval, num_householder=config.num_householder, layer_idx=layer_idx, + fuse_conv_l2=config.fuse_conv_l2, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = GatedDeltaProductMLP( diff --git a/fla/models/kda/configuration_kda.py b/fla/models/kda/configuration_kda.py index 89b162925..429ce1257 100644 --- a/fla/models/kda/configuration_kda.py +++ b/fla/models/kda/configuration_kda.py @@ -32,6 +32,7 @@ def __init__( tie_word_embeddings: bool = False, initializer_range: float = 0.02, fuse_norm: bool = True, + fuse_conv_l2: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, use_l2warp: bool = False, @@ -58,6 +59,7 @@ def __init__( self.initializer_range = initializer_range self.fuse_norm = fuse_norm + self.fuse_conv_l2 = fuse_conv_l2 self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy self.use_l2warp = use_l2warp diff --git a/fla/models/kda/modeling_kda.py b/fla/models/kda/modeling_kda.py index 6def9567e..a5e70a7a3 100644 --- a/fla/models/kda/modeling_kda.py +++ b/fla/models/kda/modeling_kda.py @@ -65,6 +65,7 @@ def __init__(self, config: KDAConfig, layer_idx: int): conv_size=config.conv_size, norm_eps=config.norm_eps, layer_idx=layer_idx, + fuse_conv_l2=config.fuse_conv_l2, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = KDAMLP( diff --git a/fla/models/mesa_net/configuration_mesa_net.py b/fla/models/mesa_net/configuration_mesa_net.py index 0c9549fd6..81c7a1886 100644 --- a/fla/models/mesa_net/configuration_mesa_net.py +++ b/fla/models/mesa_net/configuration_mesa_net.py @@ -39,6 +39,7 @@ def __init__( vocab_size: int = 32000, max_cg_step_training: int = 30, max_cg_step_decoding: int = 30, + fuse_conv_l2: bool = True, **kwargs, ): self.attn_mode = attn_mode @@ -68,6 +69,7 @@ def __init__( self.vocab_size = vocab_size self.max_cg_step_training = max_cg_step_training self.max_cg_step_decoding = max_cg_step_decoding + self.fuse_conv_l2 = fuse_conv_l2 if fuse_cross_entropy and fuse_linear_cross_entropy: raise ValueError( diff --git a/fla/models/mesa_net/modeling_mesa_net.py b/fla/models/mesa_net/modeling_mesa_net.py index 18fe23ebe..d47563b95 100644 --- a/fla/models/mesa_net/modeling_mesa_net.py +++ b/fla/models/mesa_net/modeling_mesa_net.py @@ -66,6 +66,7 @@ def __init__(self, config: MesaNetConfig, layer_idx: int): layer_idx=layer_idx, max_cg_step_training=config.max_cg_step_training, max_cg_step_decoding=config.max_cg_step_decoding, + fuse_conv_l2=config.fuse_conv_l2, ) self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) self.mlp = MesaNetMLP( diff --git a/fla/models/mom/configuration_mom.py b/fla/models/mom/configuration_mom.py index e16931d9f..d74f5cd82 100644 --- a/fla/models/mom/configuration_mom.py +++ b/fla/models/mom/configuration_mom.py @@ -41,6 +41,7 @@ def __init__( fuse_norm: bool = True, fuse_swiglu: bool = True, fuse_cross_entropy: bool = True, + fuse_conv_l2: bool = True, vocab_size: int = 32000, **kwargs, ): @@ -75,6 +76,7 @@ def __init__( self.fuse_norm = fuse_norm self.fuse_swiglu = fuse_swiglu self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_conv_l2 = fuse_conv_l2 self.vocab_size = vocab_size if self.mom_backend not in ['gated_deltanet']: diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index cac1e99a7..2ce252fe8 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -147,6 +147,7 @@ def __init__(self, config: MomConfig, layer_idx: int): capacity=config.capacity, shared_mem=config.shared_mem, single_kv_proj=config.single_kv_proj, + fuse_conv_l2=config.fuse_conv_l2, ) else: raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") diff --git a/fla/ops/convolution/__init__.py b/fla/ops/convolution/__init__.py index e69de29bb..57810e08b 100644 --- a/fla/ops/convolution/__init__.py +++ b/fla/ops/convolution/__init__.py @@ -0,0 +1,4 @@ +from .fused_short_conv import fused_short_conv + +__all__ = ['fused_short_conv'] +