
Commit 4d80f4e

remove use_flex_attn
1 parent 0d32d5a commit 4d80f4e

5 files changed: +38, -46 lines


torchtitan/distributed/activation_checkpoint.py

Lines changed: 5 additions & 5 deletions
@@ -236,7 +236,7 @@ def _apply_ac_to_transformer_block(
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
+    attn_type: str = "sdpa",
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,7 +259,7 @@ def _apply_ac_to_transformer_block(

     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if use_flex_attn:
+        if attn_type == "flex":
             """
             For Flex Attention, we need to apply SAC carefully to avoid invalidating
             torch.compile. Any torch.compile inside the SAC region will be ignored,
@@ -288,7 +288,7 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
+    attn_type: str = "sdpa",
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
@@ -302,7 +302,7 @@ def apply_ac(
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        use_flex_attn (bool): Whether flex attention is enabled for the model.
+        attn_type (str): Attention type (one of [sdpa, varlen, flex])
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
             of recomputing.
     Returns:
@@ -326,7 +326,7 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
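
The SAC wrapper no longer receives a boolean flag; it branches on the attention type string. A minimal, self-contained sketch of that dispatch follows; the wrapper helpers are hypothetical stand-ins, not torchtitan's real checkpoint utilities, and only the branching mirrors the diff above.

import torch.nn as nn

def _wrap_plain_sac(block: nn.Module) -> nn.Module:
    # Stand-in for the usual op-level SAC wrapping driven by op_sac_save_list.
    return block

def _wrap_flex_safe_sac(block: nn.Module) -> nn.Module:
    # Stand-in for the FlexAttention-aware wrapping that keeps torch.compile
    # outside the SAC region, as the docstring in the diff explains.
    return block

def apply_op_sac_sketch(block: nn.Module, attn_type: str = "sdpa") -> nn.Module:
    # Before this commit the branch keyed on use_flex_attn: bool; now any
    # non-"flex" type ("sdpa", "varlen") takes the plain path.
    if attn_type == "flex":
        return _wrap_flex_safe_sac(block)
    return _wrap_plain_sac(block)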

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 2 additions & 6 deletions
@@ -67,11 +67,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """

-    attn_type = getattr(model.model_args, "attention_type", False)
-    use_flex_attn = attn_type == "flex"
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -96,12 +91,13 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )

+    attn_type = getattr(model.model_args, "attention_type", False)
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
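
The call site now reads the attention type once from the model args and forwards it straight to apply_ac. A small sketch of the getattr fallback semantics; the args classes below are hypothetical stand-ins for illustration only.

from dataclasses import dataclass

@dataclass
class _ArgsWithAttn:
    attention_type: str = "flex"

@dataclass
class _ArgsWithoutAttn:
    dim: int = 4096

# Models that define attention_type forward it as a string ...
assert getattr(_ArgsWithAttn(), "attention_type", False) == "flex"
# ... models that do not define it fall back to False, which is neither
# "flex" nor "varlen", so apply_ac takes the default (SDPA-style) SAC path.
assert getattr(_ArgsWithoutAttn(), "attention_type", False) is False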

torchtitan/models/llama3/model/args.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs):
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True

-    attention_type: Literal["flex", "varlen"] = None
+    attention_type: Literal["flex", "varlen"] = "sdpa"
     attn_mask_type: str = "causal"
     eos_id: int = 0

@@ -58,7 +58,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:

         if (
             job_config.parallelism.context_parallel_degree > 1
-            and self.attention_type == "flex"
+            and self.attention_type != "sdpa"
         ):
             raise NotImplementedError(
                 "CP support for FlexAttention is still in progress."

torchtitan/models/llama3/model/model.py

Lines changed: 28 additions & 27 deletions
@@ -194,14 +194,14 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )

-        self.use_flex_attn = model_args.attention_type == "flex"
-        self.use_varlen_attn = model_args.attention_type == "varlen"
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        elif self.use_varlen_attn:
-            self.inner_attention = VarlenAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attention_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -246,22 +246,23 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-        elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-            output = self.inner_attention(
-                xq,
-                xk,
-                xv,
-                self.head_dim,
-                attention_masks,
-                is_causal=True,
-            )
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                    is_causal=True,
+                )
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)

         output = output.transpose(
             1, 2
@@ -468,7 +469,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
     def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        eos_id: int,
+        tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
@@ -478,7 +479,7 @@ def _get_flex_attention_masks(
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
@@ -505,7 +506,7 @@ def get_attention_masks(
         match self.model_args.attention_type:
             case "flex":
                 return self._get_flex_attention_masks(
-                    input_batch, tokenizer.eos_id, extra_inputs
+                    input_batch, tokenizer, extra_inputs
                 )
             case "varlen":
                 return self._get_varlen_attention_masks(
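
The per-flag booleans on the attention module are replaced by a single attn_type string and structural pattern matching. A standalone sketch of the selection logic; the stub classes are trivial stand-ins for FlexAttentionWrapper, VarlenAttentionWrapper, and ScaledDotProductAttentionWrapper, not the real implementations.

class _FlexStub: ...
class _VarlenStub: ...
class _SdpaStub: ...

def pick_inner_attention(attn_type: str):
    # Same shape as the match in Attention.__init__: "flex" and "varlen"
    # get dedicated wrappers, everything else (including the new "sdpa"
    # default) falls through to scaled dot-product attention.
    match attn_type:
        case "flex":
            return _FlexStub()
        case "varlen":
            return _VarlenStub()
        case _:
            return _SdpaStub()

assert isinstance(pick_inner_attention("sdpa"), _SdpaStub)
assert isinstance(pick_inner_attention("flex"), _FlexStub)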

torchtitan/train.py

Lines changed: 1 addition & 6 deletions
@@ -455,12 +455,7 @@ def post_dataloading_process(
         extra_kwargs: dict[str, Any] = {}

         attn_type = getattr(self.model_args, "attention_type", False)
-        use_varlen_attn = attn_type == "varlen"
-        use_flex_attn = (
-            getattr(self.model_args, "use_flex_attn", False) or attn_type == "flex"
-        )
-
-        if use_flex_attn or use_varlen_attn:
+        if attn_type in ["flex", "varlen"]:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
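
The trainer's mask plumbing collapses two derived booleans (including the legacy use_flex_attn attribute lookup) into one membership test on the attention type string. A tiny sketch of the new check; the args object is a hypothetical stand-in.

class _ModelArgsStub:
    attention_type = "varlen"

attn_type = getattr(_ModelArgsStub(), "attention_type", False)
if attn_type in ["flex", "varlen"]:
    # Both mask-based paths ask the model for attention masks before the
    # forward pass; plain SDPA (or a missing attribute) skips this step.
    print("build attention masks from the tokenizer and input batch")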
