2 changes: 1 addition & 1 deletion torchtitan/experiments/forge/example_train.py
@@ -161,7 +161,7 @@ def forward_backward_step(
inputs = input_dict["input"]
extra_kwargs = {}

if getattr(self.model_args, "use_flex_attn", False):
if getattr(self.model_args, "attn_type", "sdpa") == "flex":
Contributor: "varlen" should also work here?

Author: IIUC this isn't limited to llama3; I'll add varlen after more thorough testing for the other models.

(A sketch of the combined dispatch follows this file's diff.)

extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
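For context on the thread above, a minimal sketch (not part of this PR) of how the dispatch could later cover varlen as well; the "varlen" branch is an assumption, and it presumes get_attention_masks can also return the varlen metadata:

# Hypothetical follow-up, per the review discussion above; not what this PR ships.
attn_type = getattr(self.model_args, "attn_type", "sdpa")
if attn_type in ("flex", "varlen"):
    # Both paths consume per-batch masks/metadata built once per step.
    extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
        input_batch=inputs,
        tokenizer=self.tokenizer,
    )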
6 changes: 2 additions & 4 deletions torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -62,10 +62,6 @@ def parallelize_gptoss(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")

if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
@@ -111,6 +107,8 @@
job_config.compile.enable and "model" in job_config.compile.components
)

attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
4 changes: 2 additions & 2 deletions torchtitan/experiments/gpt_oss/model/args.py
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
n_kv_heads (int): Number of key-value heads.
sliding_window_size (int): Size of the sliding attention window.
attn_mask_type (str): Type of basic attention mask.
use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
attn_type (str): Attention type. Only "flex" is supported.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
n_kv_heads: int = 8
sliding_window_size: int = 128
attn_mask_type: str = "causal"
use_flex_attn: bool = True # NOTE: gpt-oss only support FlexAttention
attn_type: str = "flex" # NOTE: gpt-oss only supports FlexAttention
# yarn
original_seq_len: int = 4096
rope_theta: float = 150000.0
7 changes: 3 additions & 4 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -67,9 +67,9 @@ def parallelize_deepseekv3(

if (
job_config.parallelism.context_parallel_degree > 1
and model.model_args.use_flex_attn
and model.model_args.attn_type != "sdpa"
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")
raise NotImplementedError("CP is only supported with SDPA.")

if parallel_dims.tp_enabled:
enable_float8_linear = "float8" in job_config.model.converters
@@ -85,13 +85,12 @@
"Currently, float8 tensorwise TP is not tested for deepseekv3"
)

use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
attn_type = getattr(model.model_args, "attn_type", "sdpa")
Contributor: can remove this line.

apply_non_moe_tp(
model,
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
use_flex_attn=use_flex_attn,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

3 changes: 2 additions & 1 deletion torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -102,7 +102,8 @@ def parallelize_llama(
maybe_enable_async_tp(job_config, tp_mesh)

if job_config.activation_checkpoint.mode != "none":
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
7 changes: 4 additions & 3 deletions torchtitan/experiments/vlm/infra/parallelize.py
@@ -48,16 +48,17 @@ def parallelize_vlm(
Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")
attn_type = getattr(model.model_args, "attn_type", "sdpa")
if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
raise NotImplementedError("CP is only supported with SDPA.")

if parallel_dims.tp_enabled:
raise NotImplementedError("TP support for VLM training is still in progress.")

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
use_flex_attn = attn_type == "flex"
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
2 changes: 1 addition & 1 deletion torchtitan/experiments/vlm/model/args.py
@@ -53,7 +53,7 @@ class Siglip2ModelArgs:
spatial_merge_size: int = 1

layer_norm_eps: float = 1e-6
use_flex_attn: bool = True
attn_type: str = "flex"
attn_mask_type: str = "causal"


111 changes: 109 additions & 2 deletions torchtitan/models/attention.py
@@ -8,7 +8,7 @@

import functools
from collections.abc import Callable
from typing import ClassVar
from typing import ClassVar, NamedTuple

import torch
import torch.nn.functional as F
@@ -20,10 +20,14 @@
flex_attention,
)

from torch.nn.attention.varlen import varlen_attn


__all__ = [
"FlexAttentionWrapper",
"ScaledDotProductAttentionWrapper",
"VarlenAttentionWrapper",
"VarlenMetadata",
"get_causal_mask_mod",
"get_document_mask_mod",
"get_sliding_window_mask_mod",
@@ -32,6 +36,53 @@
]


class VarlenMetadata(NamedTuple):
"""
Cumulative sequence positions for queries and keys/values.

"""

cu_seq_q: torch.Tensor
cu_seq_k: torch.Tensor
max_q: int
max_k: int


class VarlenAttentionWrapper(torch.nn.Module):
_compiled_varlen_attn: ClassVar[Callable] = torch.compile(
varlen_attn, mode="max-autotune-no-cudagraphs"
)

def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
xv: torch.Tensor,
head_dim: int,
attention_masks: VarlenMetadata,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
cu_seq_q = attention_masks.cu_seq_q
cu_seq_k = attention_masks.cu_seq_k
max_q = attention_masks.max_q
max_k = attention_masks.max_k

n_local_heads = xq.shape[1]
xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim)
xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim)

return VarlenAttentionWrapper._compiled_varlen_attn(
xq_packed,
xk_packed,
xv_packed,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal=True,
)


class FlexAttentionWrapper(torch.nn.Module):
"""Wrapper around `flex_attention` to make it torch.compile and CP compatible.

@@ -66,7 +117,6 @@ def forward(
# `FlexAttentionWrapper._compiled_flex_attn` is correct.
# 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
# to convert `lse` to be DTensor.

return FlexAttentionWrapper._compiled_flex_attn(
q,
k,
@@ -226,3 +276,60 @@ def create_attention_mask(*args, **kwargs):
arguments.
"""
return _compiled_create_block_mask(*args, **kwargs)


def create_varlen_metadata_for_document(
input_batch: torch.Tensor, eos_id: int
) -> VarlenMetadata:
"""
Creates the cumulative sequence length indices needed for variable length attention.

Args:
input_batch: tokenized input batch of shape (batch_size, seq_len)
eos_id: token id marking document boundaries

Returns:
VarlenMetadata with packed cumulative sequence lengths for q and k, plus the max sequence length
"""
batch_size, seq_len = input_batch.shape
device = input_batch.device
cu_seqlens_list, all_seq_lengths = [], []
offset = 0
max_seqlen = 0

for b in range(batch_size):
tokens = input_batch[b]
eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)
sample_cu_seqlens = torch.cat(
[
torch.tensor([0], dtype=torch.int32, device=device),
eos_positions + 1,
torch.tensor([seq_len], dtype=torch.int32, device=device),
]
)
sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)

seq_lengths = torch.diff(sample_cu_seqlens)
all_seq_lengths.append(seq_lengths)

cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
cu_seqlens_list.append(cu_seqlens_adjusted)

offset += seq_len

packed_cu_seqlens = torch.cat(
cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]
)

max_seqlen = 0
if len(all_seq_lengths) > 0:
all_seq_lengths = torch.cat(all_seq_lengths)
# device to host sync but only done once per model forward
max_seqlen = all_seq_lengths.max().item()

return VarlenMetadata(
cu_seq_q=packed_cu_seqlens,
cu_seq_k=packed_cu_seqlens,
max_q=max_seqlen,
max_k=max_seqlen,
)
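A usage sketch of the new varlen pieces (illustrative, not from the PR; it assumes varlen_attn wants CUDA bf16 inputs). With seq_len = 8 and an EOS at positions 2 and 5, each sample contributes boundaries [0, 3, 6, 8]; when packed, the second sample's boundaries are shifted by seq_len:

# Illustrative only; requires a CUDA device, and the bf16 requirement is an assumption.
import torch

from torchtitan.models.attention import (
    VarlenAttentionWrapper,
    create_varlen_metadata_for_document,
)

eos_id = 2
bs, n_heads, seq_len, head_dim = 2, 8, 8, 64

# Two packed samples; EOS at positions 2 and 5 splits each into three documents.
tokens = torch.randint(3, 32_000, (bs, seq_len), device="cuda")
tokens[:, 2] = eos_id
tokens[:, 5] = eos_id

metadata = create_varlen_metadata_for_document(tokens, eos_id)
print(metadata.cu_seq_q)  # tensor([0, 3, 6, 8, 11, 14, 16], ...) -- second sample offset by 8
print(metadata.max_q)     # 3

# Shapes match what VarlenAttentionWrapper.forward transposes and packs.
q = torch.randn(bs, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = VarlenAttentionWrapper()(q, k, v, head_dim, metadata)  # packed over tokens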
8 changes: 4 additions & 4 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -72,7 +72,7 @@
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_type="flex",
attn_mask_type="block_causal",
),
"16B": DeepSeekV3ModelArgs(
Expand All @@ -97,7 +97,7 @@
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_type="flex",
attn_mask_type="block_causal",
),
"236B": DeepSeekV3ModelArgs(
Expand All @@ -124,7 +124,7 @@
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
use_flex_attn=True,
attn_type="flex",
attn_mask_type="block_causal",
),
"671B": DeepSeekV3ModelArgs(
Expand All @@ -151,7 +151,7 @@
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
use_flex_attn=True,
attn_type="flex",
attn_mask_type="block_causal",
),
}
26 changes: 9 additions & 17 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -61,9 +61,10 @@ def parallelize_deepseekv3(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")
attn_type = getattr(model.model_args, "attn_type", "sdpa")
Contributor: nit: in Python 3.11+, StrEnum seems like a good fit for this.

Contributor: TorchTitan still sticks to 3.10, AFAIK.

(A 3.10-compatible sketch of this idea follows this file's diff.)

use_flex_attn = attn_type == "flex"
if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
raise NotImplementedError("CP is only supported with SDPA.")

if parallel_dims.tp_enabled:
enable_float8_linear = "float8" in job_config.model.converters
@@ -84,7 +85,6 @@
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
use_flex_attn=use_flex_attn,
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

@@ -181,7 +181,7 @@ def apply_non_moe_tp(
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
use_flex_attn: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -211,18 +210,11 @@
PrepareModuleInput,
)

if use_flex_attn:
attention_kernel_plan = prepare_module_input(
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
use_local_output=True,
)
else:
attention_kernel_plan = prepare_module_input(
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
use_local_output=True,
)
Comment on lines -214 to -225

Contributor: Is this existing duplicated code?

Author: Yeah, it was there before; I removed it per #2000 (comment).

attention_kernel_plan = prepare_module_input(
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
use_local_output=True,
)
# Apply tensor + sequence parallelism to every transformer block
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
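On the StrEnum exchange above: a 3.10-compatible sketch (not part of this PR) that keeps plain-string comparisons working while constraining the accepted values; the AttnType name and its members are illustrative.

from enum import Enum, unique

@unique
class AttnType(str, Enum):
    """String-valued enum; works on Python 3.10 (StrEnum only landed in 3.11)."""
    SDPA = "sdpa"
    FLEX = "flex"
    VARLEN = "varlen"

# Members are also str instances, so existing comparisons keep working,
# and unknown config values fail loudly at parse time:
assert AttnType.FLEX == "flex"
assert AttnType("sdpa") is AttnType.SDPA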
10 changes: 4 additions & 6 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -44,7 +44,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
use_flex_attn (bool): Whether to use FlexAttention.
attn_type (str): Attention type.
attn_mask_type (str): Type of attention mask.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
@@ -76,7 +76,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
use_flex_attn: bool = False
attn_type: str = "sdpa"
attn_mask_type: str = "causal"

# yarn
@@ -101,10 +101,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.moe_args.use_grouped_mm = False

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise NotImplementedError(
"CP support for FlexAttention is still in progress."
)
if job_config.parallelism.context_parallel_degree > 1 and self.attn_type != "sdpa":
raise NotImplementedError("CP is only supported with SDPA.")

self.moe_args._debug_force_load_balance = (
job_config.debug.moe_force_load_balance