
Commit ca0efc0

remove use_flex for all other models
1 parent 4d80f4e commit ca0efc0

24 files changed (+96, -99 lines)

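Every hunk below makes the same mechanical substitution: the boolean model-args flag use_flex_attn becomes a string-valued attn_type field ("sdpa" by default, "flex" for FlexAttention), and the keyword threaded through apply_ac and the parallelize helpers is renamed to match. As orientation, a minimal sketch of the before/after shape of that API (illustrative names, not taken verbatim from the repo):

from dataclasses import dataclass

@dataclass
class OldModelArgs:
    use_flex_attn: bool = False   # old API: opt into FlexAttention with a boolean

@dataclass
class NewModelArgs:
    attn_type: str = "sdpa"       # new API: "sdpa" (default) or "flex"

def wants_flex(args) -> bool:
    # Call sites now compare the string; a missing attribute falls back to the SDPA path.
    return getattr(args, "attn_type", "sdpa") == "flex"

assert not wants_flex(OldModelArgs())               # legacy args without attn_type -> SDPA
assert wants_flex(NewModelArgs(attn_type="flex"))   # new args select FlexAttention explicitly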

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 11 additions & 11 deletions
@@ -84,7 +84,7 @@ def get_bw_flops(model_fn):
             model_selective_ac,
             ac_config_no_force,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         flops_selective_ac = get_bw_flops(model_selective_ac)
@@ -102,7 +102,7 @@ def get_bw_flops(model_fn):
             model_with_force_first,
             ac_config_with_force_first,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         flops_with_force_first = get_bw_flops(model_with_force_first)
@@ -119,7 +119,7 @@ def get_bw_flops(model_fn):
             model_with_force_last,
             ac_config_with_force_last,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -134,7 +134,7 @@ def get_bw_flops(model_fn):
             model_with_full_ac,
             ac_config_full_ac,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         flops_full_ac = get_bw_flops(model_with_full_ac)
@@ -177,7 +177,7 @@ def get_act_mem(model_fn):
             model_selective_ac,
             ac_config_no_force,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         mem_selective_ac = get_act_mem(model_selective_ac)
@@ -194,7 +194,7 @@ def get_act_mem(model_fn):
             model_with_force_first,
             ac_config_with_force_first,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         mem_with_force_first = get_act_mem(model_with_force_first)
@@ -210,7 +210,7 @@ def get_act_mem(model_fn):
             model_with_force_last,
             ac_config_with_force_last,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
            op_sac_save_list=_op_sac_save_list,
         )
         mem_with_force_last = get_act_mem(model_with_force_last)
@@ -224,7 +224,7 @@ def get_act_mem(model_fn):
             model_with_full_ac,
             ac_config_full_ac,
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         mem_full_ac = get_act_mem(model_with_full_ac)
@@ -251,7 +251,7 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=[],
             ),
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
         model_force_first = ToyModule()
@@ -264,7 +264,7 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
            ),
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )
 
@@ -278,7 +278,7 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
             ),
             model_compile_enabled=False,
-            use_flex_attn=False,
+            attn_type="sdpa",
             op_sac_save_list=_op_sac_save_list,
         )

torchtitan/experiments/forge/example_train.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def forward_backward_step(
         inputs = input_dict["input"]
         extra_kwargs = {}
 
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.model_args, "attn_type", "sdpa") == "flex":
             extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
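The hunk above gates mask construction on the new string value: attention masks are built only when the model reports attn_type == "flex", and the getattr default keeps un-migrated model args on the SDPA path. A self-contained sketch of that dispatch (build_masks and the SimpleNamespace args are stand-ins for the trainer's real model_args and get_attention_masks call):

from types import SimpleNamespace

def build_extra_kwargs(model_args, build_masks):
    # Mirrors the trainer logic above: masks are materialized only for the "flex" path.
    extra_kwargs = {}
    if getattr(model_args, "attn_type", "sdpa") == "flex":
        extra_kwargs["attention_masks"] = build_masks()
    return extra_kwargs

flex_args = SimpleNamespace(attn_type="flex")
sdpa_args = SimpleNamespace(attn_type="sdpa")
assert "attention_masks" in build_extra_kwargs(flex_args, lambda: object())
assert build_extra_kwargs(sdpa_args, lambda: object()) == {}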

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 3 additions & 3 deletions
@@ -62,8 +62,8 @@ def parallelize_gptoss(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type == "flex":
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
@@ -116,7 +116,7 @@ def parallelize_gptoss(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
         )

torchtitan/experiments/gpt_oss/model/args.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
         n_kv_heads (int): Number of key-value heads.
         sliding_window_size (int): Size of the sliding attention window.
         attn_mask_type (str): Type of basic attention mask.
-        use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
+        attn_type (str): Attention type. Only supports "flex".
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
     n_kv_heads: int = 8
     sliding_window_size: int = 128
     attn_mask_type: str = "causal"
-    use_flex_attn: bool = True  # NOTE: gpt-oss only supports FlexAttention
+    attn_type: str = "flex"  # NOTE: gpt-oss only supports FlexAttention
     # yarn
     original_seq_len: int = 4096
     rope_theta: float = 150000.0
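Because gpt-oss supports only FlexAttention, the new field is effectively pinned to "flex". A hypothetical guard, not part of this commit, showing how that constraint could be enforced on the string field:

from dataclasses import dataclass

@dataclass
class GptOssArgsSketch:
    attn_type: str = "flex"

    def __post_init__(self):
        # Reject anything other than the only supported backend.
        if self.attn_type != "flex":
            raise ValueError("gpt-oss only supports FlexAttention (attn_type='flex')")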

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 3 additions & 3 deletions
@@ -67,7 +67,7 @@ def parallelize_deepseekv3(
 
     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type == "flex"
     ):
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
@@ -85,13 +85,13 @@ def parallelize_deepseekv3(
             "Currently, float8 tensorwise TP is not tested for deepseekv3"
         )
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
     apply_non_moe_tp(
         model,
         world_mesh["tp"],
         loss_parallel=not job_config.parallelism.disable_loss_parallel,
         enable_float8_tensorwise_tp=False,
-        use_flex_attn=use_flex_attn,
+        attn_type=attn_type,
     )
     maybe_enable_async_tp(job_config, world_mesh["tp"])

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 2 additions & 2 deletions
@@ -102,15 +102,15 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)
 
     if job_config.activation_checkpoint.mode != "none":
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 3 additions & 3 deletions
@@ -48,8 +48,8 @@ def parallelize_vlm(
         Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type == "flex":
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
@@ -63,7 +63,7 @@ def parallelize_vlm(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
         )
         apply_ac(model.encoder, job_config.activation_checkpoint)

torchtitan/experiments/vlm/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class Siglip2ModelArgs:
     spatial_merge_size: int = 1
 
     layer_norm_eps: float = 1e-6
-    use_flex_attn: bool = True
+    attn_type: str = "flex"
     attn_mask_type: str = "causal"
 

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -72,7 +72,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "16B": DeepSeekV3ModelArgs(
@@ -97,7 +97,7 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
@@ -124,7 +124,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "671B": DeepSeekV3ModelArgs(
@@ -151,7 +151,7 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
 }

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 6 additions & 6 deletions
@@ -61,8 +61,8 @@ def parallelize_deepseekv3(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type == "flex":
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
@@ -84,7 +84,7 @@ def parallelize_deepseekv3(
         world_mesh["tp"],
         loss_parallel=not job_config.parallelism.disable_loss_parallel,
         enable_float8_tensorwise_tp=False,
-        use_flex_attn=use_flex_attn,
+        attn_type=attn_type,
     )
     maybe_enable_async_tp(job_config, world_mesh["tp"])
 
@@ -112,7 +112,7 @@ def parallelize_deepseekv3(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
@@ -181,7 +181,7 @@ def apply_non_moe_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
-    use_flex_attn: bool,
+    attn_type: str,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first
@@ -211,7 +211,7 @@ def apply_non_moe_tp(
         PrepareModuleInput,
     )
 
-    if use_flex_attn:
+    if attn_type == "flex":
         attention_kernel_plan = prepare_module_input(
             input_layouts=(Shard(1), Shard(1), Shard(1)),
             desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
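The last two hunks show the general plumbing: attn_type is read once from model.model_args and passed down as a keyword, and only the "flex" value switches on the FlexAttention-specific input plan. A hedged, self-contained sketch of that flow (the string return value is a placeholder for the real PrepareModuleInput plan):

from types import SimpleNamespace

def attention_input_plan(attn_type: str = "sdpa") -> str:
    # Placeholder for the plan selection done in apply_non_moe_tp.
    return "Shard(1) q/k/v inputs" if attn_type == "flex" else "default attention inputs"

model = SimpleNamespace(model_args=SimpleNamespace(attn_type="flex"))
attn_type = getattr(model.model_args, "attn_type", "sdpa")   # same read pattern as the diff
print(attention_input_plan(attn_type))                        # -> Shard(1) q/k/v inputs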
