diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py
index 8bf3a0249f..6fafa29871 100755
--- a/tests/integration_tests/features.py
+++ b/tests/integration_tests/features.py
@@ -346,6 +346,18 @@ def build_features_test_list() -> list[OverrideDefinitions]:
             "fsdp+flex_attn+per_op_sac",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_shard_degree=4",
+                    "--activation_checkpoint.mode='full'",
+                    "--model.flavor=debugmodel_varlen_attn",
+                ]
+            ],
+            "FSDP+VARLEN_ATTN",
+            "fsdp+varlen_attn",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index 7b0b0c81e9..66ad151dd0 100644
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -161,7 +161,7 @@ def forward_backward_step(
         inputs = input_dict["input"]

         extra_kwargs = {}
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.model_args, "attn_type", "sdpa") == "flex":
             extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py
index 7714d497e4..1070f58aad 100644
--- a/torchtitan/experiments/gpt_oss/infra/parallelize.py
+++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -62,10 +62,6 @@ def parallelize_gptoss(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
@@ -111,6 +107,8 @@ def parallelize_gptoss(
         job_config.compile.enable and "model" in job_config.compile.components
     )

+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py
index e78eac4d74..af4c51eadc 100644
--- a/torchtitan/experiments/gpt_oss/model/args.py
+++ b/torchtitan/experiments/gpt_oss/model/args.py
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
         n_kv_heads (int): Number of key-value heads.
         sliding_window_size (int): Size of the sliding attention window.
         attn_mask_type (str): Type of basic attention mask.
-        use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
+        attn_type (str): Attention type. Only "flex" is supported.
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
     n_kv_heads: int = 8
     sliding_window_size: int = 128
     attn_mask_type: str = "causal"
-    use_flex_attn: bool = True  # NOTE: gpt-oss only support FlexAttention
+    attn_type: str = "flex"  # NOTE: gpt-oss only supports FlexAttention
     # yarn
     original_seq_len: int = 4096
     rope_theta: float = 150000.0
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
index 6d415004cc..83e24d7dc1 100644
--- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -67,9 +67,9 @@ def parallelize_deepseekv3(

     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type != "sdpa"
     ):
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+        raise NotImplementedError("CP is only supported with SDPA attention.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -85,13 +85,11 @@ def parallelize_deepseekv3(
                 "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )

-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index 67a012a3f7..fb07ef617a 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -102,7 +102,8 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)

     if job_config.activation_checkpoint.mode != "none":
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
+        use_flex_attn = attn_type == "flex"
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index 6a97e4ece1..d418ad6edd 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -48,9 +48,9 @@ def parallelize_vlm(
         Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: raise NotImplementedError("TP support for VLM training is still in progress.") @@ -58,6 +58,7 @@ def parallelize_vlm( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py index 11b6439ddd..49ba31246b 100644 --- a/torchtitan/experiments/vlm/model/args.py +++ b/torchtitan/experiments/vlm/model/args.py @@ -53,7 +53,7 @@ class Siglip2ModelArgs: spatial_merge_size: int = 1 layer_norm_eps: float = 1e-6 - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "causal" diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 85115fef2b..cc7b87cb20 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -8,7 +8,7 @@ import functools from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, NamedTuple import torch import torch.nn.functional as F @@ -20,10 +20,14 @@ flex_attention, ) +from torch.nn.attention.varlen import varlen_attn + __all__ = [ "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", + "VarlenAttentionWrapper", + "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", "get_sliding_window_mask_mod", @@ -32,6 +36,53 @@ ] +class VarlenMetadata(NamedTuple): + """ + Cumulative sequence positions for queries and keys/values. + + """ + + cu_seq_q: torch.Tensor + cu_seq_k: torch.Tensor + max_q: int + max_k: int + + +class VarlenAttentionWrapper(torch.nn.Module): + _compiled_varlen_attn: ClassVar[Callable] = torch.compile( + varlen_attn, mode="max-autotune-no-cudagraphs" + ) + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + head_dim: torch.Tensor, + attention_masks: VarlenMetadata, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + cu_seq_q = attention_masks.cu_seq_q + cu_seq_k = attention_masks.cu_seq_k + max_q = attention_masks.max_q + max_k = attention_masks.max_k + + n_local_heads = xq.shape[1] + xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + + return VarlenAttentionWrapper._compiled_varlen_attn( + xq_packed, + xk_packed, + xv_packed, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + is_causal=True, + ) + + class FlexAttentionWrapper(torch.nn.Module): """Wrapper around `flex_attention` to make it torch.compile and CP compatible. @@ -66,7 +117,6 @@ def forward( # `FlexAttentionWrapper._compiled_flex_attn` is correct. # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation # to convert `lse` to be DTensor. - return FlexAttentionWrapper._compiled_flex_attn( q, k, @@ -226,3 +276,60 @@ def create_attention_mask(*args, **kwargs): arguments. 
""" return _compiled_create_block_mask(*args, **kwargs) + + +def create_varlen_metadata_for_document( + input_batch: torch.Tensor, eos_id: int +) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch + eos_id: the EOS id marker + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + max_seqlen = 0 + + for b in range(batch_size): + tokens = input_batch[b] + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + eos_positions + 1, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] + ) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 525bd96c13..7e2d35a5d9 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -72,7 +72,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( @@ -124,7 +124,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( @@ -151,7 +151,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 0793820ffd..a7e1ee0dc5 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -61,9 +61,10 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -84,7 +85,6 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -181,7 +181,6 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - use_flex_attn: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -211,18 +210,11 @@ def apply_non_moe_tp( PrepareModuleInput, ) - if use_flex_attn: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) - else: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48d4b5ece1..e683905878 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -44,7 +44,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. - use_flex_attn (bool): Whether to use FlexAttention. + attn_type (str): Attention type. attn_mask_type (str): Type of attention mask. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. @@ -76,7 +76,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # yarn @@ -101,10 +101,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." 
-            )
+        if job_config.parallelism.context_parallel_degree > 1 and self.attn_type != "sdpa":
+            raise NotImplementedError("CP is only supported with SDPA attention.")

         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index 3cf56eb1b2..7d7635a4ad 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -184,11 +184,12 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale

-        self.use_flex_attn = model_args.use_flex_attn
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attn_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()

     def forward(
         self,
@@ -245,14 +246,15 @@ def forward(
         k = k.transpose(1, 2)  # (bsz, n_heads, seqlen, qk_head_dim)
         v = v.transpose(1, 2)  # (bsz, n_heads, seqlen, v_head_dim)

-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask)
-            output = self.inner_attention(
-                q, k, v, block_mask=attention_masks, scale=self.softmax_scale
-            )
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(q, k, v, scale=self.softmax_scale)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask)
+                output = self.inner_attention(
+                    q, k, v, block_mask=attention_masks, scale=self.softmax_scale
+                )
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(q, k, v, scale=self.softmax_scale)

         # Reshape and project output
         output = output.transpose(
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index 191588ad9e..75ab234ebc 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -36,7 +36,16 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        use_flex_attn=True,
+        attn_type="flex",
+        attn_mask_type="block_causal",
+    ),
+    "debugmodel_varlen_attn": TransformerModelArgs(
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        vocab_size=2048,
+        rope_theta=500000,
+        attn_type="varlen",
         attn_mask_type="block_causal",
     ),
     "8B": TransformerModelArgs(
@@ -48,6 +57,28 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        attn_type="flex",
+        attn_mask_type="block_causal",
+    ),
+    "8B_varlen": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        attn_type="varlen",
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,
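The new llama3 flavors above differ from the existing ones only in attn_type / attn_mask_type; that string is what picks the inner attention wrapper when an attention block is built. A minimal sketch of the dispatch (not part of the patch; class names and import paths are assumed from the files in this diff and may differ):

    from torchtitan.models.attention import VarlenAttentionWrapper
    from torchtitan.models.llama3.model.args import TransformerModelArgs
    from torchtitan.models.llama3.model.model import Attention

    # "varlen" routes to VarlenAttentionWrapper; "flex" and "sdpa" pick the other wrappers.
    args = TransformerModelArgs(
        dim=256, n_layers=6, n_heads=16, attn_type="varlen", attn_mask_type="block_causal"
    )
    assert isinstance(Attention(args).inner_attention, VarlenAttentionWrapper)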
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -95,6 +91,8 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index d83fb83102..81680074eb 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from torch import nn - from torchtitan.config import JobConfig from torchtitan.models.utils import get_dense_model_nparams_and_flops from torchtitan.protocols.model import BaseModelArgs @@ -43,7 +42,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" eos_id: int = 0 @@ -55,7 +54,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): raise NotImplementedError( "CP support for FlexAttention is still in progress." ) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 124153f14c..74b862bf76 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -16,10 +16,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -191,11 +194,14 @@ def __init__(self, model_args: TransformerModelArgs): model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "varlen": + self.inner_attention = VarlenAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -240,16 +246,24 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - assert ( - isinstance(attention_masks, BlockMask) or attention_masks is None - ), attention_masks - - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - else: - 
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                )
+            case "sdpa":
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")

         output = output.transpose(
             1, 2
@@ -453,13 +467,14 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )

-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
+
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
@@ -470,10 +485,36 @@ def get_attention_masks(
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                 )
+
         return create_attention_mask(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attn_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer, extra_inputs
+                )
+            case "varlen":
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        "varlen attention only supports the block_causal "
+                        f"attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
@@ -497,7 +538,6 @@ def forward(

         for layer in self.layers.values():
             h = layer(h, self.freqs_cis, attention_masks=attention_masks)
-
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
         return output
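For reference, the shape bookkeeping behind the "varlen" branch above: Attention.forward hands (bs, n_heads, seqlen, head_dim) tensors plus head_dim to VarlenAttentionWrapper, which flattens batch and sequence into one packed token dimension that the cu_seq_q/cu_seq_k offsets index into. A pure-torch illustration (not part of the patch; values are random):

    import torch

    bs, n_heads, seqlen, head_dim = 2, 4, 6, 8
    xq = torch.randn(bs, n_heads, seqlen, head_dim)  # layout produced in Attention.forward

    # (bs, n_heads, seqlen, head_dim) -> (bs, seqlen, n_heads, head_dim) -> packed tokens
    xq_packed = xq.transpose(1, 2).reshape(-1, n_heads, head_dim)
    print(xq_packed.shape)  # torch.Size([12, 4, 8])

    # A VarlenMetadata cu_seq_q of [0, 3, 6, 8, 12] would mark document boundaries
    # inside those 12 packed tokens, spanning both batch rows.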
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -117,6 +113,8 @@ def parallelize_llama( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index 7fcc9871f5..a277ca382e 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # iRoPE settings # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is @@ -76,10 +76,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." - ) + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): + raise NotImplementedError("CP support is only supported for SDPA.") self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index c8241b84de..6b9d2d2d9e 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -202,11 +202,12 @@ def __init__( # values of these two variables. self.use_rope = use_rope - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -217,7 +218,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, - attention_masks: AttentionMasksType | None, + attention_masks: AttentionMasksType, ): """ Forward pass of the attention module. 
@@ -252,7 +253,7 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        if self.use_flex_attn:
+        if self.attn_type == "flex":
             assert isinstance(attention_masks, dict), attention_masks
             attention_mask = attention_masks["rope" if self.use_rope else "nope"]
             output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 6b8dc3d5a6..74254081b6 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -59,9 +59,9 @@ def parallelize_qwen3(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
+        raise NotImplementedError("CP is only supported with SDPA attention.")

     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
@@ -112,7 +112,7 @@ def parallelize_qwen3(
         model,
         job_config.activation_checkpoint,
         model_compile_enabled=model_compile_enabled,
-        use_flex_attn=use_flex_attn,
+        use_flex_attn=attn_type == "flex",
         op_sac_save_list=_op_sac_save_list,
         base_folder=job_config.job.dump_folder,
     )
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index 0c700ce2e0..2def3a949a 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -36,7 +36,7 @@ class Qwen3ModelArgs(BaseModelArgs):
     max_seq_len: int = 4096
     depth_init: bool = True

-    use_flex_attn: bool = False
+    attn_type: str = "sdpa"
     attn_mask_type: str = "causal"
     eos_id: int = 151645

diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py
index a4f0a59844..89296ed98d 100644
--- a/torchtitan/models/qwen3/model/model.py
+++ b/torchtitan/models/qwen3/model/model.py
@@ -143,7 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.head_dim
         self.scaling = self.head_dim**-0.5
-        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)
+        self.attn_type = getattr(model_args, "attn_type", "sdpa")

         # RMSNorm added here to the here to include the q-k norm
         # This is one of the main differences between Llama3 and Qwen3
@@ -167,10 +167,11 @@ def __init__(self, model_args: Qwen3ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )

-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -226,12 +227,13 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)

         output = output.transpose(
             1, 2
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index a713bec65b..4cb193c31a 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -16,9 +16,10 @@

 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
+from torchtitan.models.attention import VarlenMetadata

-AttentionMasksType = dict[str, BlockMask] | BlockMask
+AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata


 @dataclass
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 5cfab998b2..4d2f13d6da 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -454,7 +454,8 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}

-        if getattr(self.model_args, "use_flex_attn", False):
+        attn_type = getattr(self.model_args, "attn_type", "sdpa")
+        if attn_type in ["flex", "varlen"]:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
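The train.py hunk above is the only place the trainer branches on the new attention types: masks (a BlockMask for "flex", VarlenMetadata for "varlen") are built once per step and passed through as a forward kwarg, while "sdpa" leaves attention_masks unset. Sketched as a standalone helper (not part of the patch; model and tokenizer stand in for self.model_parts[0] and self.tokenizer):

    from typing import Any


    def build_extra_kwargs(model, model_args, inputs, tokenizer) -> dict[str, Any]:
        extra_kwargs: dict[str, Any] = {}
        attn_type = getattr(model_args, "attn_type", "sdpa")
        if attn_type in ("flex", "varlen"):
            extra_kwargs["attention_masks"] = model.get_attention_masks(
                input_batch=inputs, tokenizer=tokenizer
            )
        return extra_kwargs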