Commit ab033dd

remove use_flex for all other models
1 parent 4d80f4e commit ab033dd

File tree

26 files changed

+116
-143
lines changed

26 files changed

+116
-143
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 5 additions & 5 deletions
@@ -236,7 +236,7 @@ def _apply_ac_to_transformer_block(
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,7 +259,7 @@ def _apply_ac_to_transformer_block(

     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if attn_type == "flex":
+        if use_flex_attn:
             """
             For Flex Attention, we need to apply SAC carefully to avoid invalidating
             torch.compile. Any torch.compile inside the SAC region will be ignored,
@@ -288,7 +288,7 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
@@ -302,7 +302,7 @@ def apply_ac(
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        attn_type (str): Attention type (one of [sdpa, varlen, flex])
+        use_flex_attn (bool): Whether flex attention is enabled for the model.
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
             of recomputing.
     Returns:
@@ -326,7 +326,7 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            attn_type=attn_type,
+            use_flex_attn=use_flex_attn,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
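
Taken together with the parallelize changes below, the pattern is: model args now expose a string attn_type, while apply_ac() keeps a boolean use_flex_attn, so callers derive one from the other. A minimal sketch of that calling convention, assuming model, job_config, model_compile_enabled, and op_sac_save_list are already in scope and the ACConfig lives at job_config.activation_checkpoint, as in the diffs below:

from torchtitan.distributed.activation_checkpoint import apply_ac

# Model args carry the string; the AC helper takes the boolean.
attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"

if job_config.activation_checkpoint.mode != "none":
    apply_ac(
        model,
        job_config.activation_checkpoint,
        model_compile_enabled=model_compile_enabled,
        use_flex_attn=use_flex_attn,
        op_sac_save_list=op_sac_save_list,
    )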

torchtitan/experiments/forge/example_train.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def forward_backward_step(
         inputs = input_dict["input"]
         extra_kwargs = {}

-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.model_args, "attn_type", "sdpa") == "flex":
             extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 2 additions & 4 deletions
@@ -62,10 +62,6 @@ def parallelize_gptoss(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
@@ -111,6 +107,8 @@ def parallelize_gptoss(
         job_config.compile.enable and "model" in job_config.compile.components
     )

+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,

torchtitan/experiments/gpt_oss/model/args.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
         n_kv_heads (int): Number of key-value heads.
         sliding_window_size (int): Size of the sliding attention window.
         attn_mask_type (str): Type of basic attention mask.
-        use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
+        attn_type (str): Attention type. Only supports "flex".
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
     n_kv_heads: int = 8
     sliding_window_size: int = 128
     attn_mask_type: str = "causal"
-    use_flex_attn: bool = True # NOTE: gpt-oss only support FlexAttention
+    attn_type: str = "flex" # NOTE: gpt-oss only supports FlexAttention
     # yarn
     original_seq_len: int = 4096
     rope_theta: float = 150000.0
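
Because gpt-oss only supports FlexAttention, the derived boolean is always true for this model. A small illustrative sketch, assuming GptOssModelArgs can be constructed from its defaults as the field list above suggests:

from torchtitan.experiments.gpt_oss.model.args import GptOssModelArgs

args = GptOssModelArgs()                  # attn_type defaults to "flex"
use_flex_attn = args.attn_type == "flex"  # same derivation used in parallelize_gptoss above
assert use_flex_attn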

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 2 additions & 4 deletions
@@ -67,9 +67,9 @@ def parallelize_deepseekv3(

     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type != "sdpa"
     ):
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+        raise NotImplementedError("CP is only supported with SDPA.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -85,13 +85,11 @@ def parallelize_deepseekv3(
                 "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )

-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,8 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)

     if job_config.activation_checkpoint.mode != "none":
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
+        use_flex_attn = attn_type == "flex"
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 4 additions & 3 deletions
@@ -48,16 +48,17 @@ def parallelize_vlm(
         Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
+        raise NotImplementedError("CP is only supported with SDPA.")

     if parallel_dims.tp_enabled:
         raise NotImplementedError("TP support for VLM training is still in progress.")

     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,

torchtitan/experiments/vlm/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class Siglip2ModelArgs:
     spatial_merge_size: int = 1

     layer_norm_eps: float = 1e-6
-    use_flex_attn: bool = True
+    attn_type: str = "flex"
     attn_mask_type: str = "causal"

torchtitan/models/attention.py

Lines changed: 4 additions & 4 deletions
@@ -60,17 +60,16 @@ def forward(
         xv: torch.Tensor,
         head_dim: torch.Tensor,
         attention_masks: VarlenMetadata,
-        is_causal: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         cu_seq_q = attention_masks.cu_seq_q
         cu_seq_k = attention_masks.cu_seq_k
         max_q = attention_masks.max_q
         max_k = attention_masks.max_k

         n_local_heads = xq.shape[1]
-        xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
-        xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
-        xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
+        xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
+        xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
+        xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)

         return VarlenAttentionWrapper._compiled_varlen_attn(
             xq_packed,
@@ -325,6 +324,7 @@ def create_varlen_metadata_for_document(
     max_seqlen = 0
     if len(all_seq_lengths) > 0:
         all_seq_lengths = torch.cat(all_seq_lengths)
+        # device to host sync but only done once per model forward
         max_seqlen = all_seq_lengths.max().item()

     return VarlenMetadata(
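
The packing change above swaps .contiguous().view(...) for .reshape(...): both produce the same packed tensor, but reshape returns a view when the memory layout allows it and copies only when necessary, instead of always forcing a copy. A self-contained check with illustrative shapes (not the real model sizes):

import torch

bs, n_heads, seqlen, head_dim = 2, 4, 8, 16   # illustrative sizes only
xq = torch.randn(bs, n_heads, seqlen, head_dim)

packed_old = xq.transpose(1, 2).contiguous().view(-1, n_heads, head_dim)
packed_new = xq.transpose(1, 2).reshape(-1, n_heads, head_dim)

# Same values and the same (bs * seqlen, n_heads, head_dim) shape either way.
assert torch.equal(packed_old, packed_new)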

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -72,7 +72,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "16B": DeepSeekV3ModelArgs(
@@ -97,7 +97,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
@@ -124,7 +124,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "671B": DeepSeekV3ModelArgs(
@@ -151,7 +151,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
 }
