
Commit f8fa21e

adding variable length attention to llama3 8b (#2000)
**Summary**

This PR adds variable length attention (varlen) support to the Llama 3 8B model in torchtitan. We replace `use_flex_attn` with `attn_type` (one of "sdpa", "varlen", or "flex"). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined [here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py).

**Testing**

Ran loss and performance tests against flex attention. Loss is on par.

![Loss curves: varlen vs. flex attention](https://github.com/user-attachments/assets/d85dfc09-4f5e-4f82-abc9-49b870b34990)

Varlen is slightly slower than Flex due to CUDA kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today).

|          | Varlen          | Flex            |
| :------: | :-------------: | :-------------: |
| Forward  | 774us 357ns     | 722us 317ns     |
| Backward | 1ms 955us 916ns | 1ms 558us 747ns |
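At a high level, `attn_type` selects which attention wrapper the model's layers use. Below is a minimal sketch of that dispatch; the `select_attention_module` helper is illustrative only (not part of this PR), and the wrapper constructors are assumed to take no arguments, as `VarlenAttentionWrapper` in this diff does.

```python
# Illustrative sketch of the attn_type dispatch introduced by this PR.
# The wrapper classes are the ones exported by torchtitan/models/attention.py;
# the helper name below is hypothetical.
import torch

from torchtitan.models.attention import (
    FlexAttentionWrapper,
    ScaledDotProductAttentionWrapper,
    VarlenAttentionWrapper,
)


def select_attention_module(attn_type: str) -> torch.nn.Module:
    if attn_type == "flex":
        return FlexAttentionWrapper()  # block-mask based FlexAttention
    if attn_type == "varlen":
        return VarlenAttentionWrapper()  # packed varlen_attn (flash kernels)
    if attn_type == "sdpa":
        return ScaledDotProductAttentionWrapper()  # F.scaled_dot_product_attention
    raise ValueError(f"unknown attn_type: {attn_type!r}")
```
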
1 parent 58fa181 commit f8fa21e

File tree: 26 files changed (+304, -120 lines)

tests/integration_tests/features.py

Lines changed: 12 additions & 0 deletions
@@ -346,6 +346,18 @@ def build_features_test_list() -> list[OverrideDefinitions]:
             "fsdp+flex_attn+per_op_sac",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_shard_degree=4",
+                    "--activation_checkpoint.mode='full'",
+                    "--model.flavor=debugmodel_varlen_attn",
+                ]
+            ],
+            "FSDP+VARLEN_ATTN",
+            "fsdp+varlen_attn",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [

torchtitan/experiments/forge/example_train.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def forward_backward_step(
         inputs = input_dict["input"]
         extra_kwargs = {}
 
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.model_args, "attn_type", "sdpa") == "flex":
             extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 2 additions & 4 deletions
@@ -62,10 +62,6 @@ def parallelize_gptoss(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
@@ -111,6 +107,8 @@ def parallelize_gptoss(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,

torchtitan/experiments/gpt_oss/model/args.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
         n_kv_heads (int): Number of key-value heads.
         sliding_window_size (int): Size of the sliding attention window.
         attn_mask_type (str): Type of basic attention mask.
-        use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
+        attn_type (str): Attention type. Only "flex" is supported.
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
     n_kv_heads: int = 8
     sliding_window_size: int = 128
     attn_mask_type: str = "causal"
-    use_flex_attn: bool = True  # NOTE: gpt-oss only support FlexAttention
+    attn_type: str = "flex"  # NOTE: gpt-oss only supports FlexAttention
     # yarn
     original_seq_len: int = 4096
     rope_theta: float = 150000.0

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 2 additions & 4 deletions
@@ -67,9 +67,9 @@ def parallelize_deepseekv3(
 
     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type != "sdpa"
     ):
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+        raise NotImplementedError("CP is only supported with SDPA attention.")
 
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -85,13 +85,11 @@ def parallelize_deepseekv3(
             "Currently, float8 tensorwise TP is not tested for deepseekv3"
         )
 
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,8 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)
 
     if job_config.activation_checkpoint.mode != "none":
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
+        use_flex_attn = attn_type == "flex"
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 4 additions & 3 deletions
@@ -48,16 +48,17 @@ def parallelize_vlm(
         Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
+        raise NotImplementedError("CP is only supported with SDPA attention.")
 
     if parallel_dims.tp_enabled:
         raise NotImplementedError("TP support for VLM training is still in progress.")
 
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,

torchtitan/experiments/vlm/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class Siglip2ModelArgs:
     spatial_merge_size: int = 1
 
     layer_norm_eps: float = 1e-6
-    use_flex_attn: bool = True
+    attn_type: str = "flex"
     attn_mask_type: str = "causal"

torchtitan/models/attention.py

Lines changed: 109 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 import functools
 from collections.abc import Callable
-from typing import ClassVar
+from typing import ClassVar, NamedTuple
 
 import torch
 import torch.nn.functional as F
@@ -20,10 +20,14 @@
     flex_attention,
 )
 
+from torch.nn.attention.varlen import varlen_attn
+
 
 __all__ = [
     "FlexAttentionWrapper",
     "ScaledDotProductAttentionWrapper",
+    "VarlenAttentionWrapper",
+    "VarlenMetadata",
     "get_causal_mask_mod",
     "get_document_mask_mod",
     "get_sliding_window_mask_mod",
@@ -32,6 +36,53 @@
 ]
 
 
+class VarlenMetadata(NamedTuple):
+    """
+    Cumulative sequence positions for queries and keys/values.
+
+    """
+
+    cu_seq_q: torch.Tensor
+    cu_seq_k: torch.Tensor
+    max_q: int
+    max_k: int
+
+
+class VarlenAttentionWrapper(torch.nn.Module):
+    _compiled_varlen_attn: ClassVar[Callable] = torch.compile(
+        varlen_attn, mode="max-autotune-no-cudagraphs"
+    )
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        xv: torch.Tensor,
+        head_dim: int,
+        attention_masks: VarlenMetadata,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        cu_seq_q = attention_masks.cu_seq_q
+        cu_seq_k = attention_masks.cu_seq_k
+        max_q = attention_masks.max_q
+        max_k = attention_masks.max_k
+
+        n_local_heads = xq.shape[1]
+        xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
+        xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
+        xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
+
+        return VarlenAttentionWrapper._compiled_varlen_attn(
+            xq_packed,
+            xk_packed,
+            xv_packed,
+            cu_seq_q,
+            cu_seq_k,
+            max_q,
+            max_k,
+            is_causal=True,
+        )
+
+
 class FlexAttentionWrapper(torch.nn.Module):
     """Wrapper around `flex_attention` to make it torch.compile and CP compatible.

@@ -66,7 +117,6 @@ def forward(
         # `FlexAttentionWrapper._compiled_flex_attn` is correct.
         # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
         #    to convert `lse` to be DTensor.
-
         return FlexAttentionWrapper._compiled_flex_attn(
             q,
             k,
@@ -226,3 +276,60 @@ def create_attention_mask(*args, **kwargs):
     arguments.
     """
     return _compiled_create_block_mask(*args, **kwargs)
+
+
+def create_varlen_metadata_for_document(
+    input_batch: torch.Tensor, eos_id: int
+) -> VarlenMetadata:
+    """
+    Creates the cumulative sequence length indices needed for variable length attention.
+
+    Args:
+        input_batch: token batch of shape (batch_size, seq_len)
+        eos_id: the EOS token id that marks document boundaries
+
+    Returns:
+        VarlenMetadata with cumulative sequence length indices for q and k, and the max sequence length
+    """
+    batch_size, seq_len = input_batch.shape
+    device = input_batch.device
+    cu_seqlens_list, all_seq_lengths = [], []
+    offset = 0
+    max_seqlen = 0
+
+    for b in range(batch_size):
+        tokens = input_batch[b]
+        eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)
+        sample_cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], dtype=torch.int32, device=device),
+                eos_positions + 1,
+                torch.tensor([seq_len], dtype=torch.int32, device=device),
+            ]
+        )
+        sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)
+
+        seq_lengths = torch.diff(sample_cu_seqlens)
+        all_seq_lengths.append(seq_lengths)
+
+        cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
+        cu_seqlens_list.append(cu_seqlens_adjusted)
+
+        offset += seq_len
+
+    packed_cu_seqlens = torch.cat(
+        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]
+    )
+
+    max_seqlen = 0
+    if len(all_seq_lengths) > 0:
+        all_seq_lengths = torch.cat(all_seq_lengths)
+        # device-to-host sync, but only done once per model forward
+        max_seqlen = all_seq_lengths.max().item()
+
+    return VarlenMetadata(
+        cu_seq_q=packed_cu_seqlens,
+        cu_seq_k=packed_cu_seqlens,
+        max_q=max_seqlen,
+        max_k=max_seqlen,
+    )
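
For reference, a toy usage of the helpers added in this file. The token contents, shapes, EOS id, and device/dtype choices are made up for illustration, and running it requires a CUDA device with a PyTorch build that ships `torch.nn.attention.varlen`.

```python
# Toy usage of the varlen helpers above; all concrete values are illustrative.
import torch

from torchtitan.models.attention import (
    VarlenAttentionWrapper,
    create_varlen_metadata_for_document,
)

EOS = 2
# Two packed samples of length 8; EOS tokens mark document boundaries.
tokens = torch.tensor(
    [
        [5, 6, 7, EOS, 9, 10, 11, EOS],  # two documents of length 4
        [5, 6, EOS, 9, 10, 11, 12, 13],  # documents of length 3 and 5
    ],
    device="cuda",
)
meta = create_varlen_metadata_for_document(tokens, eos_id=EOS)
# meta.cu_seq_q == tensor([0, 4, 8, 11, 16], dtype=torch.int32); meta.max_q == 5

n_heads, head_dim = 4, 16
bs, seqlen = tokens.shape
# Attention inputs laid out as (batch, heads, seq, head_dim), as the wrapper expects.
xq = torch.randn(bs, n_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)

out = VarlenAttentionWrapper()(xq, xk, xv, head_dim, meta)  # packed attention output
```
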

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -72,7 +72,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "16B": DeepSeekV3ModelArgs(
@@ -97,7 +97,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
@@ -124,7 +124,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "671B": DeepSeekV3ModelArgs(
@@ -151,7 +151,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
 }
