
Commit f1d9db7

Add fuse_conv_l2 flag to conv+l2 consumers
1 parent 562a840 · commit f1d9db7

22 files changed: 107 additions & 23 deletions
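
Every touched layer follows the same pattern: when `fuse_conv_l2` is enabled and short convolutions are in use (for DeltaNet, additionally when `qk_norm == 'l2'`), the q/k `ShortConvolution` modules take over the L2 normalization (built with `norm='l2'` and called with the per-head width as `head_dim`), and the chunk/fused-recurrent kernels are invoked with `use_qk_l2norm_in_kernel` disabled so the normalization is not applied twice. A minimal standalone sketch of that gating, in plain Python with illustrative names mirroring the diffs below:

# Sketch of the fuse_conv_l2 gating added in this commit; the real logic lives in
# fla/layers/*.py and feeds ShortConvolution(norm=...) plus the kernels'
# use_qk_l2norm_in_kernel argument. The function name is illustrative, not part of fla.
def resolve_qk_l2norm(fuse_conv_l2: bool, use_short_conv: bool, qk_norm: str = 'l2') -> dict:
    # The flag only takes effect when a short convolution exists to host the fused
    # norm (and, for DeltaNet, when qk_norm is 'l2').
    fused = fuse_conv_l2 and use_short_conv and qk_norm == 'l2'
    return {
        'conv_norm': 'l2' if fused else None,                      # ShortConvolution(norm=...)
        'use_qk_l2norm_in_kernel': qk_norm == 'l2' and not fused,  # kernel-side fallback
    }

print(resolve_qk_l2norm(fuse_conv_l2=True, use_short_conv=True))
# {'conv_norm': 'l2', 'use_qk_l2norm_in_kernel': False}  -> norm fused into the conv
print(resolve_qk_l2norm(fuse_conv_l2=True, use_short_conv=False))
# {'conv_norm': None, 'use_qk_l2norm_in_kernel': True}   -> norm stays in the kernel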

fla/layers/comba.py

Lines changed: 10 additions & 2 deletions
@@ -91,6 +91,7 @@ def __init__(
         conv_bias: bool = False,
         layer_idx: int = None,
         norm_eps: float = 1e-5,
+        fuse_conv_l2: bool = True,
         **kwargs,
     ) -> Comba:
         super().__init__()
@@ -106,6 +107,7 @@ def __init__(
         self.use_inner_decay = use_inner_decay
         self.conv_size = conv_size
         self.conv_bias = conv_bias
+        self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -179,12 +181,16 @@ def __init__(
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.v_conv1d = ShortConvolution(
                 hidden_size=self.value_dim,
@@ -243,12 +249,14 @@ def forward(
                 cache=conv_state_q,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
@@ -291,7 +299,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )
         elif mode == 'fused_recurrent':
             o, recurrent_state = fused_recurrent_comba(
@@ -304,7 +312,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")
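
As a usage-level illustration of the conv side of this change (a sketch only: the tensor sizes are made up, the `fla.modules` import path is assumed, and the fused norm path will generally need a CUDA device with Triton installed):

import torch

from fla.modules import ShortConvolution  # import path assumed; check your fla version

# Hypothetical sizes, for illustration only.
batch, seq_len, num_heads, head_k_dim = 2, 128, 4, 64
key_dim = num_heads * head_k_dim

# What Comba builds for q/k when fuse_conv_l2=True: causal short conv + SiLU
# + per-head L2 norm, all handled by the ShortConvolution module.
q_conv1d = ShortConvolution(
    hidden_size=key_dim,
    kernel_size=4,
    activation='silu',
    norm='l2',
    norm_eps=1e-5,
).cuda()

x = torch.randn(batch, seq_len, key_dim, device='cuda')
# head_dim tells the module how to group channels into heads before normalizing,
# matching head_dim=self.head_k_dim in the forward pass above.
q, conv_state = q_conv1d(x=x, head_dim=head_k_dim)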

fla/layers/delta_net.py

Lines changed: 16 additions & 8 deletions
@@ -87,15 +87,23 @@ def __init__(
         qk_activation: str = 'silu',
         qk_norm: str = 'l2',
         norm_eps: float = 1e-5,
-        fuse_norm: bool = True,
+        fuse_conv_l2: bool = True,
+        fuse_norm: bool | None = None,
         **kwargs,
     ) -> DeltaNet:
         super().__init__()

         self.mode = mode
         self.qk_activation = qk_activation
         self.qk_norm = qk_norm
-        self.fuse_norm = fuse_norm and (qk_norm == 'l2')
+        if fuse_norm is not None:
+            warnings.warn(
+                "`fuse_norm` is deprecated for DeltaNet; use `fuse_conv_l2` to control the fused "
+                "ShortConvolution + L2 kernel.",
+                stacklevel=2,
+            )
+            fuse_conv_l2 = fuse_norm
+        self.fuse_conv_l2 = fuse_conv_l2 and use_short_conv and (qk_norm == 'l2')

         assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
         assert self.qk_norm in ['l2', 'sum']
@@ -138,15 +146,15 @@ def __init__(
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu' if qk_activation == 'silu' else None,
-                norm='l2' if self.fuse_norm else None,
+                norm='l2' if self.fuse_conv_l2 else None,
                 norm_eps=norm_eps,
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu' if qk_activation == 'silu' else None,
-                norm='l2' if self.fuse_norm else None,
+                norm='l2' if self.fuse_conv_l2 else None,
                 norm_eps=norm_eps,
             )
             self.v_conv1d = ShortConvolution(
@@ -206,14 +214,14 @@ def forward(
                 cache=conv_state_q,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_dim=self.head_k_dim if self.fuse_norm else None
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None
             )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_dim=self.head_k_dim if self.fuse_norm else None
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
@@ -260,7 +268,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm),
+                use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2),
             )
         elif mode == 'chunk':
             o, recurrent_state = chunk_delta_rule(
@@ -271,7 +279,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_norm),
+                use_qk_l2norm_in_kernel=(self.qk_norm == 'l2' and not self.fuse_conv_l2),
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")
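
For callers that still pass the old DeltaNet flag, the shim above forwards `fuse_norm` to `fuse_conv_l2` after emitting a warning. A hedged migration sketch (the `hidden_size` value and any constructor defaults beyond this diff are assumptions about DeltaNet's usual signature):

import warnings

from fla.layers.delta_net import DeltaNet

# Old spelling: still accepted, but now routed through the deprecation shim.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    layer = DeltaNet(hidden_size=512, fuse_norm=True)
    assert any('fuse_norm' in str(w.message) for w in caught)

# New spelling: controls the fused ShortConvolution + L2 path directly.
layer = DeltaNet(hidden_size=512, fuse_conv_l2=True)
# Disable the fusion (e.g. for debugging numerics); the L2 norm then runs inside
# the delta-rule kernel via use_qk_l2norm_in_kernel=True.
layer = DeltaNet(hidden_size=512, fuse_conv_l2=False)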

fla/layers/gated_deltanet.py

Lines changed: 10 additions & 2 deletions
@@ -100,6 +100,7 @@ def __init__(
         conv_bias: bool = False,
         layer_idx: int = None,
         norm_eps: float = 1e-5,
+        fuse_conv_l2: bool = True,
         **kwargs,
     ) -> GatedDeltaNet:
         super().__init__()
@@ -113,6 +114,7 @@ def __init__(
         self.use_short_conv = use_short_conv
         self.conv_size = conv_size
         self.conv_bias = conv_bias
+        self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -174,12 +176,16 @@ def __init__(
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.v_conv1d = ShortConvolution(
                 hidden_size=self.value_dim,
@@ -239,12 +245,14 @@ def forward(
                 cache=conv_state_q,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
@@ -280,7 +288,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )
         elif mode == 'fused_recurrent':
             o, recurrent_state = fused_recurrent_gated_delta_rule(
@@ -292,7 +300,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")

fla/layers/gated_deltaproduct.py

Lines changed: 10 additions & 2 deletions
@@ -44,6 +44,7 @@ def __init__(
         use_forget_gate: bool = True,
         allow_neg_eigval: bool = True,
         num_householder: int = 2,
+        fuse_conv_l2: bool = True,
         **kwargs,
     ) -> GatedDeltaProduct:
         super().__init__()
@@ -60,6 +61,7 @@ def __init__(
         self.use_short_conv = use_short_conv
         self.conv_size = conv_size
         self.conv_bias = conv_bias
+        self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -122,12 +124,16 @@ def __init__(
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim * num_householder,
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.v_conv1d = ShortConvolution(
                 hidden_size=self.value_dim * num_householder,
@@ -196,12 +202,14 @@ def forward(
                 cache=conv_state_q,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
@@ -243,7 +251,7 @@ def forward(
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
                 num_householder=self.num_householder,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )

         elif mode == 'fused_recurrent':
@@ -264,7 +272,7 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
             )
             o = rearrange(o, '... (t n) h d -> ... t n h d', n=self.num_householder)[..., -1, :, :].contiguous()

fla/layers/kda.py

Lines changed: 10 additions & 2 deletions
@@ -71,6 +71,7 @@ def __init__(
         conv_bias: bool = False,
         layer_idx: int = None,
         norm_eps: float = 1e-5,
+        fuse_conv_l2: bool = True,
         **kwargs,
     ) -> KimiDeltaAttention:
         super().__init__()
@@ -83,6 +84,7 @@ def __init__(
         self.use_short_conv = use_short_conv
         self.conv_size = conv_size
         self.conv_bias = conv_bias
+        self.fuse_conv_l2 = fuse_conv_l2 and self.use_short_conv

         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -122,12 +124,16 @@ def __init__(
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.k_conv1d = ShortConvolution(
                 hidden_size=self.key_dim,
                 kernel_size=conv_size,
                 bias=conv_bias,
                 activation='silu',
+                norm='l2' if self.fuse_conv_l2 else None,
+                norm_eps=norm_eps,
             )
             self.v_conv1d = ShortConvolution(
                 hidden_size=self.value_dim,
@@ -194,12 +200,14 @@ def forward(
                 cache=conv_state_q,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
            )
             k, conv_state_k = self.k_conv1d(
                 x=self.k_proj(hidden_states),
                 cache=conv_state_k,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
+                head_dim=self.head_k_dim if self.fuse_conv_l2 else None,
             )
             v, conv_state_v = self.v_conv1d(
                 x=self.v_proj(hidden_states),
@@ -237,7 +245,7 @@ def forward(
                 beta=beta,
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
                 cu_seqlens=cu_seqlens,
             )
         elif mode == 'fused_recurrent':
@@ -249,7 +257,7 @@ def forward(
                 beta=beta,
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=not self.fuse_conv_l2,
                 cu_seqlens=cu_seqlens,
             )
         else:
