From bd0c0fc1b45a2ab8f6678e8540db145e4e2cdcb0 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Tue, 4 Nov 2025 14:36:58 -0800 Subject: [PATCH 01/11] support varlen_attn for llama3 --- torchtitan/hf_datasets/text_datasets.py | 121 +++++++++++++++++- torchtitan/models/llama3/__init__.py | 8 ++ torchtitan/models/llama3/model/args.py | 1 + torchtitan/models/llama3/model/model.py | 75 ++++++++++- .../llama3/train_configs/debug_model.toml | 6 +- torchtitan/train.py | 13 +- 6 files changed, 216 insertions(+), 8 deletions(-) diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 493cd1abb4..807bc19955 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -19,6 +19,7 @@ from torchtitan.config import JobConfig from torchtitan.hf_datasets import DatasetConfig from torchtitan.tools.logging import logger +from torchtitan.protocols import train_spec def _load_c4_dataset(dataset_path: str, split: str): @@ -66,6 +67,69 @@ def _validate_dataset( logger.info(f"Preparing {dataset_name} dataset from {path}") return path, config.loader, config.sample_processor +def varlen_collate_fn(batch): + """ + Custom collate function for varlen attention. + Collapses batch dimension by packing all samples into a single sequence. + + Args: + batch: List of (input_dict, label) tuples + + Returns: + Packed (input_dict, label) with collapsed batch dimension + """ + if len(batch) == 1: + # Single sample - already packed + input_dict, label = batch[0] + return { + "input": input_dict["input"].unsqueeze(0), # [1, seq_len] + "cu_seq_q": input_dict["cu_seq_q"], + "cu_seq_k": input_dict["cu_seq_k"], + "max_q": input_dict["max_q"], + "max_k": input_dict["max_k"], + }, label.unsqueeze(0) # [1, seq_len] + + # Multiple samples - pack them together + inputs = [] + labels = [] + cu_seqlens_list = [] + offset = 0 + max_seqlen = 0 + + for input_dict, label in batch: + inputs.append(input_dict["input"]) + labels.append(label) + + # Get cu_seqlens from this sample and adjust by offset + cu_seqlens = input_dict["cu_seq_q"] + # Don't include the last boundary (we'll add it at the end) + cu_seqlens_adjusted = cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + # Track maximum sequence length across all samples + max_seqlen = max(max_seqlen, input_dict["max_q"]) + + # Update offset for next sample + offset += len(input_dict["input"]) + + # Concatenate all inputs and labels + packed_input = torch.cat(inputs, dim=0).unsqueeze(0) # Shape: [total_tokens] + packed_label = torch.cat(labels, dim=0).unsqueeze(0) # Shape: [total_tokens] + + # Combine all cu_seqlens and add final boundary + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32)] + ) + + return { + "input": packed_input, + "cu_seq_q": packed_cu_seqlens, + "cu_seq_k": packed_cu_seqlens, + "max_q": max_seqlen, + "max_k": max_seqlen, + }, packed_label + + class HuggingFaceTextDataset(IterableDataset, Stateful): def __init__( @@ -97,6 +161,9 @@ def __init__( self._sample_idx = 0 self._token_buffer: list[int] = [] + self._boundary_buffer: list[int] = [0] + self.use_varlen_attn: bool = False + def _get_data_iter(self): # For map-style datasets, resume by skipping to the correct index # For iterable-style datasets, the underlying iterator already points to the correct index @@ -121,13 +188,52 @@ def __iter__(self): self._token_buffer.extend(sample_tokens) self._sample_idx += 1 + # marks where this current document ends + if self.use_varlen_attn: 
+ self._boundary_buffer.append(len(self._token_buffer)) + while len(self._token_buffer) >= max_buffer_token_len: x = torch.LongTensor(self._token_buffer[:max_buffer_token_len]) + # update tokens to the remaining tokens self._token_buffer = self._token_buffer[max_buffer_token_len:] + input = x[:-1] label = x[1:] - yield {"input": input}, label + + if self.use_varlen_attn: + boundaries_in_window = [ + b for b in self._boundary_buffer + if b <= max_buffer_token_len + ] + + cu_seqlens = torch.tensor(boundaries_in_window, dtype=torch.int32) + + self._boundary_buffer = [ + b - max_buffer_token_len + for b in self._boundary_buffer + if b > max_buffer_token_len + ] + + if not self._boundary_buffer or self._boundary_buffer[0] != 0: + self._boundary_buffer.insert(0, 0) + + cu_seqlens_input = cu_seqlens[cu_seqlens <= len(input)] + if cu_seqlens_input[-1] != len(input): + cu_seqlens_input = torch.cat([cu_seqlens_input, torch.tensor([len(input)], dtype=torch.int32)]) + + seq_lengths = torch.diff(cu_seqlens_input) + max_seqlen = seq_lengths.max().item() if len(seq_lengths) > 0 else self.seq_len + + yield { + "input": input, + "cu_seq_q": cu_seqlens_input, + "cu_seq_k": cu_seqlens_input, + "max_q": max_seqlen, + "max_k": max_seqlen, + }, label + else: + yield {"input": input}, label if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data") @@ -145,6 +251,7 @@ def __iter__(self): def load_state_dict(self, state_dict): self._token_buffer = state_dict["token_buffer"] + self._boundary_buffer = state_dict.get("boundary_buffer", [0]) if isinstance(self._data, Dataset): self._sample_idx = state_dict["sample_idx"] @@ -153,7 +260,10 @@ def load_state_dict(self, state_dict): self._data.load_state_dict(state_dict["data"]) def state_dict(self): - _state_dict = {"token_buffer": self._token_buffer} + _state_dict = { + "token_buffer": self._token_buffer, + "boundary_buffer": self._boundary_buffer, + } if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx @@ -178,6 +288,9 @@ def build_text_dataloader( batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len + model_args = train_spec.get_train_spec(job_config.model.name).model_args[job_config.model.flavor] + use_varlen_attn = getattr(model_args, "use_varlen_attn", False) + hf_ds = HuggingFaceTextDataset( dataset_name=dataset_name, dataset_path=dataset_path, @@ -187,12 +300,16 @@ def build_text_dataloader( dp_world_size=dp_world_size, infinite=infinite, ) + hf_ds.use_varlen_attn = use_varlen_attn + + collate_fn=varlen_collate_fn if use_varlen_attn else None return ParallelAwareDataloader( dataset=hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, batch_size=batch_size, + collate_fn=collate_fn, ) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 191588ad9e..3dc695be67 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -39,6 +39,14 @@ use_flex_attn=True, attn_mask_type="block_causal", ), + "debugmodel_varlen_attn": TransformerModelArgs( + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2048, + rope_theta=500000, + use_varlen_attn=True, + ), "8B": TransformerModelArgs( dim=4096, n_layers=32, diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index d83fb83102..2bbe862837 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -44,6 +44,7 @@ class TransformerModelArgs(BaseModelArgs): depth_init: bool = True 
use_flex_attn: bool = False + use_varlen_attn: bool = False attn_mask_type: str = "causal" eos_id: int = 0 diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 124153f14c..2b2e8bd491 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -24,6 +24,8 @@ from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol +from torch.nn.attention.varlen import varlen_attn + from .args import RoPEScalingArgs, TransformerModelArgs @@ -192,8 +194,11 @@ def __init__(self, model_args: TransformerModelArgs): ) self.use_flex_attn = model_args.use_flex_attn + self.use_varlen_attn = model_args.use_varlen_attn if self.use_flex_attn: self.inner_attention = FlexAttentionWrapper() + elif self.use_varlen_attn: + self.inner_attention = varlen_attn else: self.inner_attention = ScaledDotProductAttentionWrapper() @@ -202,11 +207,54 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + def _apply_rotary_per_sequence( + self, + xq: torch.Tensor, # [bs, total_tokens, n_heads, head_dim] + xk: torch.Tensor, + freqs_cis: torch.Tensor, + cu_seqlens: torch.Tensor, # [num_sequences + 1] + ): + xq = xq.squeeze(0) # [total_tokens, n_heads, head_dim] + xk = xk.squeeze(0) + + xq_out_list = [] + xk_out_list = [] + + for i in range(len(cu_seqlens) - 1): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + seq_len = end_idx - start_idx + + # extract this sequence + xq_seq = xq[start_idx:end_idx] # [seq_len, n_heads, head_dim] + xk_seq = xk[start_idx:end_idx] + + # get freqs_cis for this sequence length (positions 0 to seq_len-1) + freqs_cis_seq = freqs_cis[:seq_len] # [seq_len, head_dim/2] + + # apply RoPE to this sequence + xq_seq_rope, xk_seq_rope = apply_rotary_emb( + xq_seq.unsqueeze(0), # add batch dim back + xk_seq.unsqueeze(0), + freqs_cis=freqs_cis_seq + ) + + xq_out_list.append(xq_seq_rope.squeeze(0)) + xk_out_list.append(xk_seq_rope.squeeze(0)) + + # concatenate all sequences back together + xq_out = torch.cat(xq_out_list, dim=0) # [total_tokens, n_heads, head_dim] + xk_out = torch.cat(xk_out_list, dim=0) + + # add batch dimension back + return xq_out.unsqueeze(0), xk_out.unsqueeze(0) + def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + **kwargs ): """ Forward pass of the attention module. 
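
The per-document RoPE loop added above restarts positions at 0 for every packed document by slicing xq/xk along cu_seqlens. The same effect can be obtained without a Python loop by building per-token position ids and gathering freqs_cis with them. The sketch below is illustrative only (positions_from_cu_seqlens is a hypothetical helper, not part of this patch), assuming cu_seqlens is the int32 [num_docs + 1] prefix-sum tensor used elsewhere in the patch.

import torch

def positions_from_cu_seqlens(cu_seqlens: torch.Tensor, total_tokens: int) -> torch.Tensor:
    # cu_seqlens: [num_docs + 1], e.g. [0, 1536, 4096, 8192] for three packed documents.
    # Returns position ids that restart at 0 at each document boundary, so rotary
    # frequencies can be gathered per token instead of applied document-by-document.
    cu = cu_seqlens.to(torch.int64)
    positions = torch.arange(total_tokens, dtype=torch.int64, device=cu.device)
    doc_ids = torch.searchsorted(cu, positions, right=True) - 1  # which document each token is in
    return positions - cu[doc_ids]                               # offset within that document

# e.g. freqs_cis_per_token = freqs_cis[positions_from_cu_seqlens(cu_seq_q, total_tokens)]
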
@@ -230,7 +278,12 @@ def forward( xk = xk.view(bs, seqlen, -1, self.head_dim) xv = xv.view(bs, seqlen, -1, self.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + if self.use_varlen_attn: + cu_seq_q = kwargs.get("cu_seq_q") + assert(cu_seq_q is not None) + xq, xk = self._apply_rotary_per_sequence(xq, xk, freqs_cis, cu_seq_q) + else: + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -247,6 +300,20 @@ def forward( if self.use_flex_attn: assert isinstance(attention_masks, BlockMask), attention_masks output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + elif self.use_varlen_attn: + cu_seq_q = kwargs.get("cu_seq_q") + cu_seq_k = kwargs.get("cu_seq_k") + max_q = kwargs.get("max_q") + max_k = kwargs.get("max_k") + + n_local_heads = xq.shape[1] + + xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + + + output = self.inner_attention(xq_packed, xk_packed, xv_packed, cu_seq_q, cu_seq_k, max_q, max_k) else: assert attention_masks is None output = self.inner_attention(xq, xk, xv) @@ -346,6 +413,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + **kwargs ): """ Perform a forward pass through the TransformerBlock. @@ -358,7 +426,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks, **kwargs) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -478,6 +546,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + **kwargs ): """ Perform a forward pass through the Transformer model. 
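
For reference, the packed layout the hunk above feeds into varlen_attn can be exercised in isolation. This is a minimal sketch that mirrors the call signature used in the patch (torch.nn.attention.varlen is a prototype API, so exact requirements may differ); the CUDA device and bfloat16 dtype are assumptions, chosen because fused varlen kernels typically need them.

import torch
from torch.nn.attention.varlen import varlen_attn

# Two documents of lengths 3 and 5 packed into a single 8-token row.
n_heads, head_dim = 4, 16
cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
q = torch.randn(8, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention is applied within each document; tokens never attend across
# the boundaries recorded in cu_seq.
out = varlen_attn(q, k, v, cu_seq, cu_seq, 5, 5, is_causal=True)  # max_q = max_k = 5
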
@@ -496,7 +565,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks=attention_masks) + h = layer(h, self.freqs_cis, attention_masks=attention_masks, **kwargs) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 7760667edd..f7dafbc042 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -19,7 +19,8 @@ enable_wandb = false [model] name = "llama3" -flavor = "debugmodel" +flavor = "debugmodel_varlen_attn" +# flavor = "debugmodel_flex_attn" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] @@ -78,3 +79,6 @@ enable = false dataset = "c4_validation" freq = 5 steps = 10 + +[debug] +seed = 42 diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..05c70685dc 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,9 +12,9 @@ import torch -from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.protocols.train_spec as train_spec_module + +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -396,6 +396,9 @@ def batch_generator( # entire step will not be executed. raise DataloaderExhaustedError() from ex input_dict, labels = batch + + # print(f"input_dict: {input_dict["input"]}") + # print(f"labels: {labels}") ntokens_batch = labels.numel() self.ntokens_seen += ntokens_batch self.metrics_processor.ntokens_since_last_log += ntokens_batch @@ -461,6 +464,12 @@ def post_dataloading_process( extra_inputs=extra_inputs, ) + if getattr(self.model_args, "use_varlen_attn", False): + extra_kwargs["cu_seq_q"] = extra_inputs.pop("cu_seq_q", None) + extra_kwargs["cu_seq_k"] = extra_inputs.pop("cu_seq_k", None) + extra_kwargs["max_q"] = extra_inputs.pop("max_q", None) + extra_kwargs["max_k"] = extra_inputs.pop("max_k", None) + return inputs, labels, extra_inputs, extra_kwargs def forward_backward_step( From ea085c188dfc01e68bddeee7b3605d9e614f32c9 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 6 Nov 2025 15:45:47 -0800 Subject: [PATCH 02/11] testing --- torchtitan/models/llama3/__init__.py | 3 +++ torchtitan/models/llama3/train_configs/llama3_8b.toml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 3dc695be67..5fbcf1b139 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -55,6 +55,9 @@ ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=500000, + # use_flex_attn=True, + # attn_mask_type="block_causal", + use_varlen_attn=True, ), "70B": TransformerModelArgs( dim=8192, diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ef86d783bf..7aa53dfb5c 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -68,3 +68,6 @@ enable = false dataset = "c4_validation" freq = 500 steps = 1200 # Recommend value for c4_validation with world-size=8 
and seq_len=8192 + +[debug] +seed = 42 From b3f723d5319d19a7cbd0d5e038dec4386e13f83b Mon Sep 17 00:00:00 2001 From: Angel Li Date: Tue, 11 Nov 2025 15:48:27 -0800 Subject: [PATCH 03/11] fixing is_causal and .item() --- torchtitan/hf_datasets/text_datasets.py | 7 ++++-- torchtitan/models/attention.py | 24 +++++++++++++++++++ torchtitan/models/llama3/model/model.py | 11 +++++---- .../llama3/train_configs/llama3_8b.toml | 2 +- torchtitan/tools/profiling.py | 1 + torchtitan/train.py | 19 ++++++++++----- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 807bc19955..240bc859bf 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -130,7 +130,6 @@ def varlen_collate_fn(batch): }, packed_label - class HuggingFaceTextDataset(IterableDataset, Stateful): def __init__( self, @@ -190,6 +189,7 @@ def __iter__(self): # marks where this current document ends if self.use_varlen_attn: + # if self.use_varlen_attn or self.use_flex_attn: self._boundary_buffer.append(len(self._token_buffer)) while len(self._token_buffer) >= max_buffer_token_len: @@ -198,16 +198,19 @@ def __iter__(self): # update tokens to the remaining tokens self._token_buffer = self._token_buffer[max_buffer_token_len:] - input = x[:-1] + input = x[:-1] # print device here label = x[1:] if self.use_varlen_attn: + # if self.use_varlen_attn or self.use_flex_attn: boundaries_in_window = [ b for b in self._boundary_buffer if b <= max_buffer_token_len ] cu_seqlens = torch.tensor(boundaries_in_window, dtype=torch.int32) + # print device here + self._boundary_buffer = [ b - max_buffer_token_len diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 85115fef2b..c356ee7ea0 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -20,10 +20,13 @@ flex_attention, ) +from torch.nn.attention.varlen import varlen_attn + __all__ = [ "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", + "VarlenAttentionWrapper", "get_causal_mask_mod", "get_document_mask_mod", "get_sliding_window_mask_mod", @@ -31,6 +34,27 @@ "create_attention_mask", ] +class VarlenAttentionWrapper(torch.nn.Module): + _compiled_varlen_attn: ClassVar[Callable] = torch.compile( + varlen_attn, mode="max-autotune-no-cudagraphs" + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = True, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + + return VarlenAttentionWrapper._compiled_varlen_attn( + q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True + ) + class FlexAttentionWrapper(torch.nn.Module): """Wrapper around `flex_attention` to make it torch.compile and CP compatible. 
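
One way to sanity-check the wrapper added above is to compare its output against plain scaled_dot_product_attention run one document at a time; mathematically the two should agree up to numerical tolerance. The helper below is a sketch for such a check, not part of this patch, and assumes the packed (total_tokens, n_heads, head_dim) layout used by the varlen path.

import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, cu_seq):
    # q, k, v: (total_tokens, n_heads, head_dim) with documents packed back to back.
    # Runs ordinary causal SDPA per document, which is what varlen attention fuses
    # into a single kernel over the whole packed row.
    outs = []
    for start, end in zip(cu_seq[:-1].tolist(), cu_seq[1:].tolist()):
        qs = q[start:end].transpose(0, 1)  # (n_heads, doc_len, head_dim)
        ks = k[start:end].transpose(0, 1)
        vs = v[start:end].transpose(0, 1)
        o = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
        outs.append(o.transpose(0, 1))     # back to (doc_len, n_heads, head_dim)
    return torch.cat(outs, dim=0)

# e.g. torch.testing.assert_close(sdpa_reference(q, k, v, cu_seq), varlen_output, rtol=2e-2, atol=2e-2)
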
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 2b2e8bd491..3a5adb82c8 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -212,7 +212,7 @@ def _apply_rotary_per_sequence( xq: torch.Tensor, # [bs, total_tokens, n_heads, head_dim] xk: torch.Tensor, freqs_cis: torch.Tensor, - cu_seqlens: torch.Tensor, # [num_sequences + 1] + cu_seqlens: list, # [num_sequences + 1] ): xq = xq.squeeze(0) # [total_tokens, n_heads, head_dim] xk = xk.squeeze(0) @@ -221,8 +221,8 @@ def _apply_rotary_per_sequence( xk_out_list = [] for i in range(len(cu_seqlens) - 1): - start_idx = cu_seqlens[i].item() - end_idx = cu_seqlens[i + 1].item() + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] seq_len = end_idx - start_idx # extract this sequence @@ -279,8 +279,9 @@ def forward( xv = xv.view(bs, seqlen, -1, self.head_dim) if self.use_varlen_attn: - cu_seq_q = kwargs.get("cu_seq_q") + cu_seq_q = kwargs.get("cu_seq_q_list") assert(cu_seq_q is not None) + assert(type(cu_seq_q) is list) xq, xk = self._apply_rotary_per_sequence(xq, xk, freqs_cis, cu_seq_q) else: xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) @@ -313,7 +314,7 @@ def forward( xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) - output = self.inner_attention(xq_packed, xk_packed, xv_packed, cu_seq_q, cu_seq_k, max_q, max_k) + output = self.inner_attention(xq_packed, xk_packed, xv_packed, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True) else: assert attention_masks is None output = self.inner_attention(xq, xk, xv) diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 7aa53dfb5c..17d635269a 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -32,7 +32,7 @@ warmup_steps = 200 # lr scheduler warm up local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 100 dataset = "c4" [parallelism] diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index f398dba9b5..f1a040e8c4 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -76,6 +76,7 @@ def trace_handler(prof): schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, record_shapes=True, + # with_stack=True, # python stack ) as torch_profiler: torch_profiler.step_num = global_step yield torch_profiler diff --git a/torchtitan/train.py b/torchtitan/train.py index 05c70685dc..5f41ad6d1f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -397,8 +397,6 @@ def batch_generator( raise DataloaderExhaustedError() from ex input_dict, labels = batch - # print(f"input_dict: {input_dict["input"]}") - # print(f"labels: {labels}") ntokens_batch = labels.numel() self.ntokens_seen += ntokens_batch self.metrics_processor.ntokens_since_last_log += ntokens_batch @@ -407,11 +405,14 @@ def batch_generator( ) # Move tensors to the appropriate device - for k, v in input_dict.items(): + for k in list(input_dict.keys()): + v = input_dict[k] + if "cu_seq" in k: + input_dict[k+"_list"] = v.tolist() if isinstance(v, torch.Tensor): input_dict[k] = v.to(device_type) - labels = labels.to(device_type) + labels = labels.to(device_type) yield input_dict, labels def post_dataloading_process( @@ -465,10 +466,12 @@ def post_dataloading_process( ) if getattr(self.model_args, "use_varlen_attn", False): - 
extra_kwargs["cu_seq_q"] = extra_inputs.pop("cu_seq_q", None) - extra_kwargs["cu_seq_k"] = extra_inputs.pop("cu_seq_k", None) + extra_kwargs["cu_seq_q_cpu"] = extra_inputs.pop("cu_seq_q_cpu", None) + extra_kwargs["cu_seq_k_cpu"] = extra_inputs.pop("cu_seq_k_cpu", None) extra_kwargs["max_q"] = extra_inputs.pop("max_q", None) extra_kwargs["max_k"] = extra_inputs.pop("max_k", None) + # print("extra kwargs") + # print(extra_kwargs["cu_seq_q"].device) return inputs, labels, extra_inputs, extra_kwargs @@ -530,6 +533,8 @@ def forward_backward_step( ) else: # Non-PP forward / backward + # print("before transformers") + # print(extra_kwargs.get("cu_seq_q_cpu").device) with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -557,6 +562,8 @@ def train_step( # entire step will not be executed. for _microbatch in range(self.gradient_accumulation_steps): input_dict, labels = next(data_iterator) + # print("in train step") + # print(input_dict["cu_seq_q"].device) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) From d8a6254f372018f5bc30f105177e32c8c28a29df Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 12 Nov 2025 13:52:20 -0800 Subject: [PATCH 04/11] correct loss/perf --- torchtitan/hf_datasets/text_datasets.py | 18 +++--------------- torchtitan/models/llama3/__init__.py | 6 +++--- torchtitan/models/llama3/model/model.py | 15 ++++++++++++++- .../llama3/train_configs/debug_model.toml | 2 +- .../models/llama3/train_configs/llama3_8b.toml | 4 ++-- torchtitan/tools/profiling.py | 2 +- 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 240bc859bf..94febc8126 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -79,7 +79,6 @@ def varlen_collate_fn(batch): Packed (input_dict, label) with collapsed batch dimension """ if len(batch) == 1: - # Single sample - already packed input_dict, label = batch[0] return { "input": input_dict["input"].unsqueeze(0), # [1, seq_len] @@ -89,7 +88,6 @@ def varlen_collate_fn(batch): "max_k": input_dict["max_k"], }, label.unsqueeze(0) # [1, seq_len] - # Multiple samples - pack them together inputs = [] labels = [] cu_seqlens_list = [] @@ -100,23 +98,17 @@ def varlen_collate_fn(batch): inputs.append(input_dict["input"]) labels.append(label) - # Get cu_seqlens from this sample and adjust by offset cu_seqlens = input_dict["cu_seq_q"] - # Don't include the last boundary (we'll add it at the end) cu_seqlens_adjusted = cu_seqlens[:-1] + offset cu_seqlens_list.append(cu_seqlens_adjusted) - # Track maximum sequence length across all samples max_seqlen = max(max_seqlen, input_dict["max_q"]) - # Update offset for next sample offset += len(input_dict["input"]) - # Concatenate all inputs and labels - packed_input = torch.cat(inputs, dim=0).unsqueeze(0) # Shape: [total_tokens] - packed_label = torch.cat(labels, dim=0).unsqueeze(0) # Shape: [total_tokens] + packed_input = torch.cat(inputs, dim=0).unsqueeze(0) # shape: [1, total_tokens] + packed_label = torch.cat(labels, dim=0).unsqueeze(0) # shape: [1, total_tokens] - # Combine all cu_seqlens and add final boundary packed_cu_seqlens = torch.cat( cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32)] ) @@ -189,7 +181,6 @@ def __iter__(self): # marks where this current document ends if self.use_varlen_attn: - # if self.use_varlen_attn or self.use_flex_attn: 
self._boundary_buffer.append(len(self._token_buffer)) while len(self._token_buffer) >= max_buffer_token_len: @@ -198,19 +189,16 @@ def __iter__(self): # update tokens to the remaining tokens self._token_buffer = self._token_buffer[max_buffer_token_len:] - input = x[:-1] # print device here + input = x[:-1] label = x[1:] if self.use_varlen_attn: - # if self.use_varlen_attn or self.use_flex_attn: boundaries_in_window = [ b for b in self._boundary_buffer if b <= max_buffer_token_len ] cu_seqlens = torch.tensor(boundaries_in_window, dtype=torch.int32) - # print device here - self._boundary_buffer = [ b - max_buffer_token_len diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 5fbcf1b139..bdf986546e 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -55,9 +55,9 @@ ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=500000, - # use_flex_attn=True, - # attn_mask_type="block_causal", - use_varlen_attn=True, + use_flex_attn=True, + attn_mask_type="block_causal", + # use_varlen_attn=True, ), "70B": TransformerModelArgs( dim=8192, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 3a5adb82c8..60d7346133 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -134,8 +134,10 @@ def apply_rotary_emb( Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) @@ -282,7 +284,18 @@ def forward( cu_seq_q = kwargs.get("cu_seq_q_list") assert(cu_seq_q is not None) assert(type(cu_seq_q) is list) - xq, xk = self._apply_rotary_per_sequence(xq, xk, freqs_cis, cu_seq_q) + + true_seq_len = freqs_cis.shape[0] + total_tokens = xq.shape[1] + + true_bs = total_tokens // true_seq_len + xq = xq.view(true_bs, true_seq_len, -1, self.head_dim) + xk = xk.view(true_bs, true_seq_len, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + + xq = xq.view(1, total_tokens, -1, self.head_dim) + xk = xk.view(1, total_tokens, -1, self.head_dim) else: xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index f7dafbc042..0bd10deb24 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -19,7 +19,7 @@ enable_wandb = false [model] name = "llama3" -flavor = "debugmodel_varlen_attn" +flavor = "debugmodel_flex_attn" # flavor = "debugmodel_flex_attn" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 17d635269a..ad4643f7b3 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -6,7 +6,7 @@ description = "Llama 3 8B training" [profiling] enable_profiling = true -save_traces_folder = "profile_trace" +save_traces_folder = "flex_profile_trace" profile_freq = 100 [metrics] @@ -32,7 +32,7 @@ warmup_steps = 200 # lr scheduler warm up 
local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping -steps = 100 +steps = 1000 dataset = "c4" [parallelism] diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index f1a040e8c4..ee9ac09984 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -76,7 +76,7 @@ def trace_handler(prof): schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, record_shapes=True, - # with_stack=True, # python stack + with_stack=True, # python stack ) as torch_profiler: torch_profiler.step_num = global_step yield torch_profiler From 9381adadf1828e4e9a58c00fda1398a482601f11 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 12 Nov 2025 14:38:02 -0800 Subject: [PATCH 05/11] cleaning up --- torchtitan/hf_datasets/text_datasets.py | 40 ++++++--- torchtitan/models/attention.py | 1 + torchtitan/models/llama3/__init__.py | 20 ++++- torchtitan/models/llama3/model/model.py | 87 ++++++------------- .../llama3/train_configs/debug_model.toml | 6 +- .../llama3/train_configs/llama3_8b.toml | 5 +- .../train_configs/llama3_8b_varlen.toml | 73 ++++++++++++++++ torchtitan/tools/profiling.py | 1 - torchtitan/train.py | 22 ++--- 9 files changed, 156 insertions(+), 99 deletions(-) create mode 100644 torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 94febc8126..8ef9e77a8d 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -18,8 +18,8 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig from torchtitan.hf_datasets import DatasetConfig -from torchtitan.tools.logging import logger from torchtitan.protocols import train_spec +from torchtitan.tools.logging import logger def _load_c4_dataset(dataset_path: str, split: str): @@ -67,16 +67,17 @@ def _validate_dataset( logger.info(f"Preparing {dataset_name} dataset from {path}") return path, config.loader, config.sample_processor + def varlen_collate_fn(batch): """ - Custom collate function for varlen attention. - Collapses batch dimension by packing all samples into a single sequence. 
+ Custom collate function for variable length attention + Collapses batch dimension by packing all samples into one sequence Args: batch: List of (input_dict, label) tuples Returns: - Packed (input_dict, label) with collapsed batch dimension + packed (input_dict, label) with collapsed batch dimension """ if len(batch) == 1: input_dict, label = batch[0] @@ -86,7 +87,9 @@ def varlen_collate_fn(batch): "cu_seq_k": input_dict["cu_seq_k"], "max_q": input_dict["max_q"], "max_k": input_dict["max_k"], - }, label.unsqueeze(0) # [1, seq_len] + }, label.unsqueeze( + 0 + ) # [1, seq_len] inputs = [] labels = [] @@ -179,7 +182,6 @@ def __iter__(self): self._token_buffer.extend(sample_tokens) self._sample_idx += 1 - # marks where this current document ends if self.use_varlen_attn: self._boundary_buffer.append(len(self._token_buffer)) @@ -194,11 +196,14 @@ def __iter__(self): if self.use_varlen_attn: boundaries_in_window = [ - b for b in self._boundary_buffer + b + for b in self._boundary_buffer if b <= max_buffer_token_len ] - cu_seqlens = torch.tensor(boundaries_in_window, dtype=torch.int32) + cu_seqlens = torch.tensor( + boundaries_in_window, dtype=torch.int32 + ) self._boundary_buffer = [ b - max_buffer_token_len @@ -211,10 +216,19 @@ def __iter__(self): cu_seqlens_input = cu_seqlens[cu_seqlens <= len(input)] if cu_seqlens_input[-1] != len(input): - cu_seqlens_input = torch.cat([cu_seqlens_input, torch.tensor([len(input)], dtype=torch.int32)]) + cu_seqlens_input = torch.cat( + [ + cu_seqlens_input, + torch.tensor([len(input)], dtype=torch.int32), + ] + ) seq_lengths = torch.diff(cu_seqlens_input) - max_seqlen = seq_lengths.max().item() if len(seq_lengths) > 0 else self.seq_len + max_seqlen = ( + seq_lengths.max().item() + if len(seq_lengths) > 0 + else self.seq_len + ) yield { "input": input, @@ -279,7 +293,9 @@ def build_text_dataloader( batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len - model_args = train_spec.get_train_spec(job_config.model.name).model_args[job_config.model.flavor] + model_args = train_spec.get_train_spec(job_config.model.name).model_args[ + job_config.model.flavor + ] use_varlen_attn = getattr(model_args, "use_varlen_attn", False) hf_ds = HuggingFaceTextDataset( @@ -293,7 +309,7 @@ def build_text_dataloader( ) hf_ds.use_varlen_attn = use_varlen_attn - collate_fn=varlen_collate_fn if use_varlen_attn else None + collate_fn = varlen_collate_fn if use_varlen_attn else None return ParallelAwareDataloader( dataset=hf_ds, diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index c356ee7ea0..8c73898cc2 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -34,6 +34,7 @@ "create_attention_mask", ] + class VarlenAttentionWrapper(torch.nn.Module): _compiled_varlen_attn: ClassVar[Callable] = torch.compile( varlen_attn, mode="max-autotune-no-cudagraphs" diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index bdf986546e..eee99bebbc 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -55,9 +55,27 @@ ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=500000, + ), + "8B_flex": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, use_flex_attn=True, attn_mask_type="block_causal", - # use_varlen_attn=True, + ), + "8B_varlen": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + 
multiple_of=1024, + rope_theta=500000, + use_varlen_attn=True, ), "70B": TransformerModelArgs( dim=8192, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 60d7346133..0dc50c19f8 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -13,6 +13,8 @@ from torch import nn from torch.nn.attention.flex_attention import and_masks, BlockMask +from torch.nn.attention.varlen import varlen_attn + from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, @@ -24,8 +26,6 @@ from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol -from torch.nn.attention.varlen import varlen_attn - from .args import RoPEScalingArgs, TransformerModelArgs @@ -134,10 +134,8 @@ def apply_rotary_emb( Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) @@ -209,54 +207,12 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - def _apply_rotary_per_sequence( - self, - xq: torch.Tensor, # [bs, total_tokens, n_heads, head_dim] - xk: torch.Tensor, - freqs_cis: torch.Tensor, - cu_seqlens: list, # [num_sequences + 1] - ): - xq = xq.squeeze(0) # [total_tokens, n_heads, head_dim] - xk = xk.squeeze(0) - - xq_out_list = [] - xk_out_list = [] - - for i in range(len(cu_seqlens) - 1): - start_idx = cu_seqlens[i] - end_idx = cu_seqlens[i + 1] - seq_len = end_idx - start_idx - - # extract this sequence - xq_seq = xq[start_idx:end_idx] # [seq_len, n_heads, head_dim] - xk_seq = xk[start_idx:end_idx] - - # get freqs_cis for this sequence length (positions 0 to seq_len-1) - freqs_cis_seq = freqs_cis[:seq_len] # [seq_len, head_dim/2] - - # apply RoPE to this sequence - xq_seq_rope, xk_seq_rope = apply_rotary_emb( - xq_seq.unsqueeze(0), # add batch dim back - xk_seq.unsqueeze(0), - freqs_cis=freqs_cis_seq - ) - - xq_out_list.append(xq_seq_rope.squeeze(0)) - xk_out_list.append(xk_seq_rope.squeeze(0)) - - # concatenate all sequences back together - xq_out = torch.cat(xq_out_list, dim=0) # [total_tokens, n_heads, head_dim] - xk_out = torch.cat(xk_out_list, dim=0) - - # add batch dimension back - return xq_out.unsqueeze(0), xk_out.unsqueeze(0) - def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - **kwargs + **kwargs, ): """ Forward pass of the attention module. 
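
The 8B_flex flavor above reaches the same goal as the varlen path through flex attention's block_causal masking: a mask_mod that only lets tokens attend within their own document. The snippet below is a generic sketch of such a mask (it is not torchtitan's get_document_mask_mod, whose details may differ), assuming eos_id marks document ends in the packed input batch.

import torch
import torch.nn.functional as F

def make_document_causal_mask_mod(tokens: torch.Tensor, eos_id: int):
    # doc_id[b, i] = index of the document that position i belongs to in batch row b,
    # derived by counting EOS tokens seen strictly before position i.
    eos_seen = (tokens == eos_id).cumsum(dim=-1)
    doc_id = F.pad(eos_seen[:, :-1], (1, 0))

    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_id[b, q_idx] == doc_id[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return mask_mod

# The returned mask_mod would then be passed to create_block_mask / flex_attention.
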
@@ -281,10 +237,6 @@ def forward( xv = xv.view(bs, seqlen, -1, self.head_dim) if self.use_varlen_attn: - cu_seq_q = kwargs.get("cu_seq_q_list") - assert(cu_seq_q is not None) - assert(type(cu_seq_q) is list) - true_seq_len = freqs_cis.shape[0] total_tokens = xq.shape[1] @@ -321,13 +273,26 @@ def forward( max_k = kwargs.get("max_k") n_local_heads = xq.shape[1] + xq_packed = ( + xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + ) + xk_packed = ( + xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + ) + xv_packed = ( + xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) + ) - xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) - xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) - xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim) - - - output = self.inner_attention(xq_packed, xk_packed, xv_packed, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True) + output = self.inner_attention( + xq_packed, + xk_packed, + xv_packed, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + is_causal=True, + ) else: assert attention_masks is None output = self.inner_attention(xq, xk, xv) @@ -427,7 +392,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - **kwargs + **kwargs, ): """ Perform a forward pass through the TransformerBlock. @@ -440,7 +405,9 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks, **kwargs) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, **kwargs + ) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -560,7 +527,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - **kwargs + **kwargs, ): """ Perform a forward pass through the Transformer model. 
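
To make the metadata consumed by the varlen branch above concrete: with a hypothetical 8192-token packed row holding documents of lengths 1000, 3000, and 4192, the cumulative boundaries and the max_q / max_k values work out as below (numbers are illustrative only).

import torch

cu_seq = torch.tensor([0, 1000, 4000, 8192], dtype=torch.int32)  # [num_docs + 1] prefix sums
seq_lengths = torch.diff(cu_seq)     # tensor([1000, 3000, 4192], dtype=torch.int32)
max_seqlen = int(seq_lengths.max())  # 4192, passed to the kernel as max_q and max_k
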
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 0bd10deb24..7760667edd 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -19,8 +19,7 @@ enable_wandb = false [model] name = "llama3" -flavor = "debugmodel_flex_attn" -# flavor = "debugmodel_flex_attn" +flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] @@ -79,6 +78,3 @@ enable = false dataset = "c4_validation" freq = 5 steps = 10 - -[debug] -seed = 42 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ad4643f7b3..ef86d783bf 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -6,7 +6,7 @@ description = "Llama 3 8B training" [profiling] enable_profiling = true -save_traces_folder = "flex_profile_trace" +save_traces_folder = "profile_trace" profile_freq = 100 [metrics] @@ -68,6 +68,3 @@ enable = false dataset = "c4_validation" freq = 500 steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 - -[debug] -seed = 42 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml b/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml new file mode 100644 index 0000000000..d4415aea3a --- /dev/null +++ b/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml @@ -0,0 +1,73 @@ +# NOTE: this toml config is a preset for 64 A100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 3 8B training" + +[profiling] +enable_profiling = true +save_traces_folder = "varlen_profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "8B" +hf_assets_path = "./assets/hf/Llama-3.1-8B" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 3e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up + +[training] +local_batch_size = 1 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 1000 +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[compile] +enable=false +components = ["model", "loss"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 500 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 + +[debug] +seed = 42 diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index ee9ac09984..f398dba9b5 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -76,7 +76,6 @@ def trace_handler(prof): schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, record_shapes=True, - 
with_stack=True, # python stack ) as torch_profiler: torch_profiler.step_num = global_step yield torch_profiler diff --git a/torchtitan/train.py b/torchtitan/train.py index 5f41ad6d1f..382dde3874 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,9 +12,9 @@ import torch -import torchtitan.protocols.train_spec as train_spec_module - from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -396,7 +396,6 @@ def batch_generator( # entire step will not be executed. raise DataloaderExhaustedError() from ex input_dict, labels = batch - ntokens_batch = labels.numel() self.ntokens_seen += ntokens_batch self.metrics_processor.ntokens_since_last_log += ntokens_batch @@ -405,14 +404,11 @@ def batch_generator( ) # Move tensors to the appropriate device - for k in list(input_dict.keys()): - v = input_dict[k] - if "cu_seq" in k: - input_dict[k+"_list"] = v.tolist() + for k, v in input_dict.items(): if isinstance(v, torch.Tensor): input_dict[k] = v.to(device_type) - labels = labels.to(device_type) + yield input_dict, labels def post_dataloading_process( @@ -466,12 +462,10 @@ def post_dataloading_process( ) if getattr(self.model_args, "use_varlen_attn", False): - extra_kwargs["cu_seq_q_cpu"] = extra_inputs.pop("cu_seq_q_cpu", None) - extra_kwargs["cu_seq_k_cpu"] = extra_inputs.pop("cu_seq_k_cpu", None) + extra_kwargs["cu_seq_q"] = extra_inputs.pop("cu_seq_q", None) + extra_kwargs["cu_seq_k"] = extra_inputs.pop("cu_seq_k", None) extra_kwargs["max_q"] = extra_inputs.pop("max_q", None) extra_kwargs["max_k"] = extra_inputs.pop("max_k", None) - # print("extra kwargs") - # print(extra_kwargs["cu_seq_q"].device) return inputs, labels, extra_inputs, extra_kwargs @@ -533,8 +527,6 @@ def forward_backward_step( ) else: # Non-PP forward / backward - # print("before transformers") - # print(extra_kwargs.get("cu_seq_q_cpu").device) with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: @@ -562,8 +554,6 @@ def train_step( # entire step will not be executed. 
for _microbatch in range(self.gradient_accumulation_steps): input_dict, labels = next(data_iterator) - # print("in train step") - # print(input_dict["cu_seq_q"].device) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) From 0012170a0708b751408e6085f970082ecd678d2b Mon Sep 17 00:00:00 2001 From: Angel Li Date: Fri, 14 Nov 2025 09:54:30 -0800 Subject: [PATCH 06/11] collapse batch outside of dataloader --- torchtitan/hf_datasets/text_datasets.py | 128 +----------------- torchtitan/models/attention.py | 56 +++++++- torchtitan/models/llama3/__init__.py | 2 + torchtitan/models/llama3/model/model.py | 34 ++--- .../train_configs/llama3_8b_varlen.toml | 73 ---------- torchtitan/protocols/model.py | 3 +- torchtitan/train.py | 10 +- 7 files changed, 75 insertions(+), 231 deletions(-) delete mode 100644 torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 8ef9e77a8d..493cd1abb4 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -18,7 +18,6 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig from torchtitan.hf_datasets import DatasetConfig -from torchtitan.protocols import train_spec from torchtitan.tools.logging import logger @@ -68,63 +67,6 @@ def _validate_dataset( return path, config.loader, config.sample_processor -def varlen_collate_fn(batch): - """ - Custom collate function for variable length attention - Collapses batch dimension by packing all samples into one sequence - - Args: - batch: List of (input_dict, label) tuples - - Returns: - packed (input_dict, label) with collapsed batch dimension - """ - if len(batch) == 1: - input_dict, label = batch[0] - return { - "input": input_dict["input"].unsqueeze(0), # [1, seq_len] - "cu_seq_q": input_dict["cu_seq_q"], - "cu_seq_k": input_dict["cu_seq_k"], - "max_q": input_dict["max_q"], - "max_k": input_dict["max_k"], - }, label.unsqueeze( - 0 - ) # [1, seq_len] - - inputs = [] - labels = [] - cu_seqlens_list = [] - offset = 0 - max_seqlen = 0 - - for input_dict, label in batch: - inputs.append(input_dict["input"]) - labels.append(label) - - cu_seqlens = input_dict["cu_seq_q"] - cu_seqlens_adjusted = cu_seqlens[:-1] + offset - cu_seqlens_list.append(cu_seqlens_adjusted) - - max_seqlen = max(max_seqlen, input_dict["max_q"]) - - offset += len(input_dict["input"]) - - packed_input = torch.cat(inputs, dim=0).unsqueeze(0) # shape: [1, total_tokens] - packed_label = torch.cat(labels, dim=0).unsqueeze(0) # shape: [1, total_tokens] - - packed_cu_seqlens = torch.cat( - cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32)] - ) - - return { - "input": packed_input, - "cu_seq_q": packed_cu_seqlens, - "cu_seq_k": packed_cu_seqlens, - "max_q": max_seqlen, - "max_k": max_seqlen, - }, packed_label - - class HuggingFaceTextDataset(IterableDataset, Stateful): def __init__( self, @@ -155,9 +97,6 @@ def __init__( self._sample_idx = 0 self._token_buffer: list[int] = [] - self._boundary_buffer: list[int] = [0] - self.use_varlen_attn: bool = False - def _get_data_iter(self): # For map-style datasets, resume by skipping to the correct index # For iterable-style datasets, the underlying iterator already points to the correct index @@ -182,63 +121,13 @@ def __iter__(self): self._token_buffer.extend(sample_tokens) self._sample_idx += 1 - if self.use_varlen_attn: - self._boundary_buffer.append(len(self._token_buffer)) - 
while len(self._token_buffer) >= max_buffer_token_len: x = torch.LongTensor(self._token_buffer[:max_buffer_token_len]) - # update tokens to the remaining tokens self._token_buffer = self._token_buffer[max_buffer_token_len:] - input = x[:-1] label = x[1:] - - if self.use_varlen_attn: - boundaries_in_window = [ - b - for b in self._boundary_buffer - if b <= max_buffer_token_len - ] - - cu_seqlens = torch.tensor( - boundaries_in_window, dtype=torch.int32 - ) - - self._boundary_buffer = [ - b - max_buffer_token_len - for b in self._boundary_buffer - if b > max_buffer_token_len - ] - - if not self._boundary_buffer or self._boundary_buffer[0] != 0: - self._boundary_buffer.insert(0, 0) - - cu_seqlens_input = cu_seqlens[cu_seqlens <= len(input)] - if cu_seqlens_input[-1] != len(input): - cu_seqlens_input = torch.cat( - [ - cu_seqlens_input, - torch.tensor([len(input)], dtype=torch.int32), - ] - ) - - seq_lengths = torch.diff(cu_seqlens_input) - max_seqlen = ( - seq_lengths.max().item() - if len(seq_lengths) > 0 - else self.seq_len - ) - - yield { - "input": input, - "cu_seq_q": cu_seqlens_input, - "cu_seq_k": cu_seqlens_input, - "max_q": max_seqlen, - "max_k": max_seqlen, - }, label - else: - yield {"input": input}, label + yield {"input": input}, label if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data") @@ -256,7 +145,6 @@ def __iter__(self): def load_state_dict(self, state_dict): self._token_buffer = state_dict["token_buffer"] - self._boundary_buffer = state_dict.get("boundary_buffer", [0]) if isinstance(self._data, Dataset): self._sample_idx = state_dict["sample_idx"] @@ -265,10 +153,7 @@ def load_state_dict(self, state_dict): self._data.load_state_dict(state_dict["data"]) def state_dict(self): - _state_dict = { - "token_buffer": self._token_buffer, - "boundary_buffer": self._boundary_buffer, - } + _state_dict = {"token_buffer": self._token_buffer} if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx @@ -293,11 +178,6 @@ def build_text_dataloader( batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len - model_args = train_spec.get_train_spec(job_config.model.name).model_args[ - job_config.model.flavor - ] - use_varlen_attn = getattr(model_args, "use_varlen_attn", False) - hf_ds = HuggingFaceTextDataset( dataset_name=dataset_name, dataset_path=dataset_path, @@ -307,16 +187,12 @@ def build_text_dataloader( dp_world_size=dp_world_size, infinite=infinite, ) - hf_ds.use_varlen_attn = use_varlen_attn - - collate_fn = varlen_collate_fn if use_varlen_attn else None return ParallelAwareDataloader( dataset=hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, batch_size=batch_size, - collate_fn=collate_fn, ) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 8c73898cc2..689d44d732 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -20,7 +20,7 @@ flex_attention, ) -from torch.nn.attention.varlen import varlen_attn +from torch.nn.attention.varlen import varlen_attn, VarlenMetadata __all__ = [ @@ -251,3 +251,57 @@ def create_attention_mask(*args, **kwargs): arguments. 
""" return _compiled_create_block_mask(*args, **kwargs) + + +def create_varlen_cu_seqs(input_batch: torch.Tensor, eos_id: int) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch + eos_id: the EOS id marker + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + max_seqlen = 0 + + for b in range(batch_size): + tokens = input_batch[b] + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + eos_positions + 1, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] + ) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index eee99bebbc..63a3750b96 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -46,6 +46,7 @@ vocab_size=2048, rope_theta=500000, use_varlen_attn=True, + attn_mask_type="varlen_attn", ), "8B": TransformerModelArgs( dim=4096, @@ -76,6 +77,7 @@ multiple_of=1024, rope_theta=500000, use_varlen_attn=True, + attn_mask_type="varlen_attn", ), "70B": TransformerModelArgs( dim=8192, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 0dc50c19f8..ffd1eb1ab2 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -13,11 +13,12 @@ from torch import nn from torch.nn.attention.flex_attention import and_masks, BlockMask -from torch.nn.attention.varlen import varlen_attn +from torch.nn.attention.varlen import varlen_attn, VarlenMetadata from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_varlen_cu_seqs, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, @@ -236,20 +237,7 @@ def forward( xk = xk.view(bs, seqlen, -1, self.head_dim) xv = xv.view(bs, seqlen, -1, self.head_dim) - if self.use_varlen_attn: - true_seq_len = freqs_cis.shape[0] - total_tokens = xq.shape[1] - - true_bs = total_tokens // true_seq_len - xq = xq.view(true_bs, true_seq_len, -1, self.head_dim) - xk = xk.view(true_bs, true_seq_len, -1, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis) - - xq = xq.view(1, total_tokens, -1, self.head_dim) - xk = xk.view(1, total_tokens, -1, self.head_dim) - else: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -259,18 +247,16 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, 
head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - assert ( - isinstance(attention_masks, BlockMask) or attention_masks is None - ), attention_masks - if self.use_flex_attn: assert isinstance(attention_masks, BlockMask), attention_masks output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) elif self.use_varlen_attn: - cu_seq_q = kwargs.get("cu_seq_q") - cu_seq_k = kwargs.get("cu_seq_k") - max_q = kwargs.get("max_q") - max_k = kwargs.get("max_k") + assert isinstance(attention_masks, VarlenMetadata), attention_masks + + cu_seq_q = attention_masks.cu_seq_q + cu_seq_k = attention_masks.cu_seq_k + max_q = attention_masks.max_q + max_k = attention_masks.max_k n_local_heads = xq.shape[1] xq_packed = ( @@ -515,6 +501,8 @@ def get_attention_masks( case "block_causal": B = input_batch.shape[0] mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case "varlen_attn": + return create_varlen_cu_seqs(input_batch, tokenizer.eos_id) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml b/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml deleted file mode 100644 index d4415aea3a..0000000000 --- a/torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml +++ /dev/null @@ -1,73 +0,0 @@ -# NOTE: this toml config is a preset for 64 A100 GPUs. - -[job] -dump_folder = "./outputs" -description = "Llama 3 8B training" - -[profiling] -enable_profiling = true -save_traces_folder = "varlen_profile_trace" -profile_freq = 100 - -[metrics] -log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" - -[model] -name = "llama3" -flavor = "8B" -hf_assets_path = "./assets/hf/Llama-3.1-8B" -# converters = ["float8"] - -[optimizer] -name = "AdamW" -lr = 3e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 200 # lr scheduler warm up - -[training] -local_batch_size = 1 -seq_len = 8192 -max_norm = 1.0 # grad norm clipping -steps = 1000 -dataset = "c4" - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -context_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 500 -last_save_model_only = true -export_dtype = "float32" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] - -[compile] -enable=false -components = ["model", "loss"] - -[activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] -selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output"] - -[validation] -enable = false -dataset = "c4_validation" -freq = 500 -steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 - -[debug] -seed = 42 diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a713bec65b..2b960dae05 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -12,13 +12,14 @@ import torch.nn as nn from torch.nn.attention.flex_attention import BlockMask +from torch.nn.attention.varlen import VarlenMetadata from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig -AttentionMasksType = dict[str, BlockMask] | BlockMask +AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata @dataclass diff --git 
a/torchtitan/train.py b/torchtitan/train.py index 382dde3874..e0e616c5b7 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -454,19 +454,15 @@ def post_dataloading_process( # extra_kwargs are. extra_kwargs: dict[str, Any] = {} - if getattr(self.model_args, "use_flex_attn", False): + if getattr(self.model_args, "use_flex_attn", False) or getattr( + self.model_args, "use_varlen_attn", False + ): extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) - if getattr(self.model_args, "use_varlen_attn", False): - extra_kwargs["cu_seq_q"] = extra_inputs.pop("cu_seq_q", None) - extra_kwargs["cu_seq_k"] = extra_inputs.pop("cu_seq_k", None) - extra_kwargs["max_q"] = extra_inputs.pop("max_q", None) - extra_kwargs["max_k"] = extra_inputs.pop("max_k", None) - return inputs, labels, extra_inputs, extra_kwargs def forward_backward_step( From 4fef6ebd8ac74682a648c8446009de05a991d7b9 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Mon, 17 Nov 2025 07:47:22 -0800 Subject: [PATCH 07/11] remove explicit mask def --- torchtitan/models/attention.py | 17 +++++++++++++++-- torchtitan/models/llama3/__init__.py | 2 -- torchtitan/models/llama3/model/model.py | 11 ++++------- torchtitan/protocols/model.py | 2 +- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 689d44d732..73f79abfd5 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -8,7 +8,7 @@ import functools from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, NamedTuple import torch import torch.nn.functional as F @@ -20,13 +20,14 @@ flex_attention, ) -from torch.nn.attention.varlen import varlen_attn, VarlenMetadata +from torch.nn.attention.varlen import varlen_attn __all__ = [ "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", "VarlenAttentionWrapper", + "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", "get_sliding_window_mask_mod", @@ -35,6 +36,18 @@ ] +class VarlenMetadata(NamedTuple): + """ + Cumulative sequence positions for queries and keys/values. 
+ + """ + + cu_seq_q: torch.Tensor + cu_seq_k: torch.Tensor + max_q: int + max_k: int + + class VarlenAttentionWrapper(torch.nn.Module): _compiled_varlen_attn: ClassVar[Callable] = torch.compile( varlen_attn, mode="max-autotune-no-cudagraphs" diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 63a3750b96..eee99bebbc 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -46,7 +46,6 @@ vocab_size=2048, rope_theta=500000, use_varlen_attn=True, - attn_mask_type="varlen_attn", ), "8B": TransformerModelArgs( dim=4096, @@ -77,7 +76,6 @@ multiple_of=1024, rope_theta=500000, use_varlen_attn=True, - attn_mask_type="varlen_attn", ), "70B": TransformerModelArgs( dim=8192, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index ffd1eb1ab2..48991343bc 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -13,8 +13,6 @@ from torch import nn from torch.nn.attention.flex_attention import and_masks, BlockMask -from torch.nn.attention.varlen import varlen_attn, VarlenMetadata - from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, @@ -23,6 +21,7 @@ get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, ) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -199,7 +198,7 @@ def __init__(self, model_args: TransformerModelArgs): if self.use_flex_attn: self.inner_attention = FlexAttentionWrapper() elif self.use_varlen_attn: - self.inner_attention = varlen_attn + self.inner_attention = VarlenAttentionWrapper() else: self.inner_attention = ScaledDotProductAttentionWrapper() @@ -251,8 +250,6 @@ def forward( assert isinstance(attention_masks, BlockMask), attention_masks output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) elif self.use_varlen_attn: - assert isinstance(attention_masks, VarlenMetadata), attention_masks - cu_seq_q = attention_masks.cu_seq_q cu_seq_k = attention_masks.cu_seq_k max_q = attention_masks.max_q @@ -495,14 +492,14 @@ def get_attention_masks( extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: mask_mods = [get_causal_mask_mod()] + if self.model_args.use_varlen_attn: + return create_varlen_cu_seqs(input_batch, tokenizer.eos_id) match self.model_args.attn_mask_type: case "causal": B = 1 case "block_causal": B = input_batch.shape[0] mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) - case "varlen_attn": - return create_varlen_cu_seqs(input_batch, tokenizer.eos_id) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 2b960dae05..4cb193c31a 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -12,11 +12,11 @@ import torch.nn as nn from torch.nn.attention.flex_attention import BlockMask -from torch.nn.attention.varlen import VarlenMetadata from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig +from torchtitan.models.attention import VarlenMetadata AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata From 0d32d5af19106dd13fbc008abe09bc20450c7306 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Tue, 18 Nov 2025 08:53:15 -0800 Subject: [PATCH 08/11] attention_type --- 
 torchtitan/models/attention.py                | 35 +++++---
 torchtitan/models/llama3/__init__.py          |  8 +-
 torchtitan/models/llama3/infra/parallelize.py |  3 +-
 torchtitan/models/llama3/model/args.py        | 11 ++-
 torchtitan/models/llama3/model/model.py       | 84 ++++++++++---------
 torchtitan/protocols/model.py                 | 18 ++++
 torchtitan/train.py                           | 10 ++-
 7 files changed, 108 insertions(+), 61 deletions(-)

diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index 73f79abfd5..11d3a35418 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -55,18 +55,32 @@ class VarlenAttentionWrapper(torch.nn.Module):
 
     def forward(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        cu_seq_q: torch.Tensor,
-        cu_seq_k: torch.Tensor,
-        max_q: int,
-        max_k: int,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        xv: torch.Tensor,
+        head_dim: torch.Tensor,
+        attention_masks: VarlenMetadata,
         is_causal: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        cu_seq_q = attention_masks.cu_seq_q
+        cu_seq_k = attention_masks.cu_seq_k
+        max_q = attention_masks.max_q
+        max_k = attention_masks.max_k
+
+        n_local_heads = xq.shape[1]
+        xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
+        xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
+        xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
         return VarlenAttentionWrapper._compiled_varlen_attn(
-            q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True
+            xq_packed,
+            xk_packed,
+            xv_packed,
+            cu_seq_q,
+            cu_seq_k,
+            max_q,
+            max_k,
+            is_causal=True,
         )
 
@@ -104,7 +118,6 @@ def forward(
         #    `FlexAttentionWrapper._compiled_flex_attn` is correct.
         # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
         #    to convert `lse` to be DTensor.
-
         return FlexAttentionWrapper._compiled_flex_attn(
             q,
             k,
@@ -266,7 +279,9 @@ def create_attention_mask(*args, **kwargs):
     return _compiled_create_block_mask(*args, **kwargs)
 
 
-def create_varlen_cu_seqs(input_batch: torch.Tensor, eos_id: int) -> VarlenMetadata:
+def create_varlen_metadata_for_document(
+    input_batch: torch.Tensor, eos_id: int
+) -> VarlenMetadata:
     """
     Creates cumulative sequence length indices needed for variable length attention
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index eee99bebbc..aab2f7fb61 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -36,7 +36,7 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        use_flex_attn=True,
+        attention_type="flex",
         attn_mask_type="block_causal",
     ),
     "debugmodel_varlen_attn": TransformerModelArgs(
@@ -45,7 +45,7 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        use_varlen_attn=True,
+        attention_type="varlen",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -64,7 +64,7 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        use_flex_attn=True,
+        attention_type="flex",
         attn_mask_type="block_causal",
     ),
     "8B_varlen": TransformerModelArgs(
@@ -75,7 +75,7 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        use_varlen_attn=True,
+        attention_type="varlen",
     ),
     "70B": TransformerModelArgs(
         dim=8192,
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 86ac3a6dfe..b0c4cedc98 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -67,7 +67,8 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + attn_type = getattr(model.model_args, "attention_type", False) + use_flex_attn = attn_type == "flex" if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 2bbe862837..84f47fb58c 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -9,8 +9,9 @@ from dataclasses import dataclass, field -from torch import nn +from typing import Literal +from torch import nn from torchtitan.config import JobConfig from torchtitan.models.utils import get_dense_model_nparams_and_flops from torchtitan.protocols.model import BaseModelArgs @@ -43,8 +44,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False - use_varlen_attn: bool = False + attention_type: Literal["flex", "varlen"] = None attn_mask_type: str = "causal" eos_id: int = 0 @@ -56,7 +56,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attention_type == "flex" + ): raise NotImplementedError( "CP support for FlexAttention is still in progress." ) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 48991343bc..a9690e06ea 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -16,12 +16,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, - create_varlen_cu_seqs, + create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -193,8 +194,8 @@ def __init__(self, model_args: TransformerModelArgs): model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.use_flex_attn = model_args.use_flex_attn - self.use_varlen_attn = model_args.use_varlen_attn + self.use_flex_attn = model_args.attention_type == "flex" + self.use_varlen_attn = model_args.attention_type == "varlen" if self.use_flex_attn: self.inner_attention = FlexAttentionWrapper() elif self.use_varlen_attn: @@ -212,7 +213,6 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - **kwargs, ): """ Forward pass of the attention module. 
@@ -250,30 +250,13 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            cu_seq_q = attention_masks.cu_seq_q
-            cu_seq_k = attention_masks.cu_seq_k
-            max_q = attention_masks.max_q
-            max_k = attention_masks.max_k
-
-            n_local_heads = xq.shape[1]
-            xq_packed = (
-                xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xk_packed = (
-                xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xv_packed = (
-                xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-
+            assert isinstance(attention_masks, VarlenMetadata), attention_masks
             output = self.inner_attention(
-                xq_packed,
-                xk_packed,
-                xv_packed,
-                cu_seq_q,
-                cu_seq_k,
-                max_q,
-                max_k,
+                xq,
+                xk,
+                xv,
+                self.head_dim,
+                attention_masks,
                 is_causal=True,
             )
         else:
@@ -375,7 +358,6 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -388,9 +370,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(
-            self.attention_norm(x), freqs_cis, attention_masks, **kwargs
-        )
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
@@ -485,34 +465,61 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )
 
-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        tokenizer: BaseTokenizer,
+        eos_id: int,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
-        if self.model_args.use_varlen_attn:
-            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
+
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                 )
+
         return create_attention_mask(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def _get_varlen_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        return create_varlen_metadata_for_document(input_batch, eos_id)
+
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attention_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case "varlen":
+                return self._get_varlen_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the Transformer model.
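
To make the varlen mask construction above concrete, here is a small standalone toy example (not part of the patch) of turning EOS positions in a packed row of documents into cumulative sequence lengths, mirroring what `create_varlen_metadata_for_document` computes per batch row.

```python
import torch

eos_id = 0
# one packed row containing three documents separated by EOS tokens
batch = torch.tensor([[5, 7, eos_id, 9, 2, eos_id, 4, 3]])
seq_len = batch.shape[1]

eos_pos = (batch[0] == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)
cu_seqlens = torch.unique_consecutive(
    torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            eos_pos + 1,
            torch.tensor([seq_len], dtype=torch.int32),
        ]
    )
)
print(cu_seqlens)                    # tensor([0, 3, 6, 8], dtype=torch.int32)
print(torch.diff(cu_seqlens).max())  # longest document in the row: 3 tokens
```

The per-row boundaries are then offset and concatenated across the batch, and the maximum document length becomes `max_q`/`max_k` in the resulting `VarlenMetadata`.
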
@@ -531,8 +538,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks, **kwargs)
-
+            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
         return output
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index 4cb193c31a..c8c0bad2c7 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -71,3 +71,21 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def _get_varlen_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        raise NotImplementedError(
+            "This model does not support variable length attention."
+        )
+
+    def _get_flex_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        raise NotImplementedError("This model does not support flex attention.")
diff --git a/torchtitan/train.py b/torchtitan/train.py
index e0e616c5b7..fd37dee8d0 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -454,9 +454,13 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}
 
-        if getattr(self.model_args, "use_flex_attn", False) or getattr(
-            self.model_args, "use_varlen_attn", False
-        ):
+        attn_type = getattr(self.model_args, "attention_type", False)
+        use_varlen_attn = attn_type == "varlen"
+        use_flex_attn = (
+            getattr(self.model_args, "use_flex_attn", False) or attn_type == "flex"
+        )
+
+        if use_flex_attn or use_varlen_attn:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,

From 4d80f4e9c0fadef1495783ce69f293a5f8346120 Mon Sep 17 00:00:00 2001
From: Angel Li
Date: Wed, 19 Nov 2025 07:34:07 -0800
Subject: [PATCH 09/11] remove use_flex_attn

---
 .../distributed/activation_checkpoint.py      | 10 ++--
 torchtitan/models/llama3/infra/parallelize.py |  8 +--
 torchtitan/models/llama3/model/args.py        |  4 +-
 torchtitan/models/llama3/model/model.py       | 55 ++++++++++---------
 torchtitan/train.py                           |  7 +--
 5 files changed, 38 insertions(+), 46 deletions(-)

diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index 8359f71730..a6936ac871 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -236,7 +236,7 @@ def _apply_ac_to_transformer_block(
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
+    attn_type: str = "sdpa",
    op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,7 +259,7 @@ def _apply_ac_to_transformer_block(
     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if use_flex_attn:
+        if attn_type == "flex":
             """
             For Flex Attention, we need to apply SAC carefully to avoid invalidating
             torch.compile. Any torch.compile inside the SAC region will be ignored,
@@ -288,7 +288,7 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
+    attn_type: str = "sdpa",
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
@@ -302,7 +302,7 @@ def apply_ac(
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        use_flex_attn (bool): Whether flex attention is enabled for the model.
+        attn_type (str): Attention type (one of [sdpa, varlen, flex])
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing.
 
     Returns:
@@ -326,7 +326,7 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index b0c4cedc98..e99c62633d 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -67,11 +67,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    attn_type = getattr(model.model_args, "attention_type", False)
-    use_flex_attn = attn_type == "flex"
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -96,12 +91,13 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
+    attn_type = getattr(model.model_args, "attention_type", False)
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            attn_type=attn_type,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index 84f47fb58c..7f1aa41e71 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs):
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
 
-    attention_type: Literal["flex", "varlen"] = None
+    attention_type: Literal["flex", "varlen"] = "sdpa"
     attn_mask_type: str = "causal"
     eos_id: int = 0
 
@@ -58,7 +58,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
 
         if (
             job_config.parallelism.context_parallel_degree > 1
-            and self.attention_type == "flex"
+            and self.attention_type != "sdpa"
         ):
             raise NotImplementedError(
                 "CP support for FlexAttention is still in progress."
             )
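
A condensed illustration of the config guard introduced in this hunk, as toy standalone Python (names are illustrative stand-ins, not the real JobConfig/ModelArgs types): context parallelism is only allowed together with SDPA in this series.

```python
from dataclasses import dataclass


@dataclass
class _Args:
    attention_type: str = "sdpa"


def check_cp(args: _Args, cp_degree: int) -> None:
    # mirror the guard: any non-SDPA attention type rejects CP > 1
    if cp_degree > 1 and args.attention_type != "sdpa":
        raise NotImplementedError("CP is only supported with SDPA attention.")


check_cp(_Args("sdpa"), cp_degree=2)      # passes
# check_cp(_Args("varlen"), cp_degree=2)  # would raise NotImplementedError
```
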
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index a9690e06ea..4b861d6181 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -194,14 +194,14 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-        self.use_flex_attn = model_args.attention_type == "flex"
-        self.use_varlen_attn = model_args.attention_type == "varlen"
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        elif self.use_varlen_attn:
-            self.inner_attention = VarlenAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attention_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -246,22 +246,23 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-        elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-            output = self.inner_attention(
-                xq,
-                xk,
-                xv,
-                self.head_dim,
-                attention_masks,
-                is_causal=True,
-            )
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                    is_causal=True,
+                )
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
@@ -468,7 +469,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
     def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        eos_id: int,
+        tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
@@ -478,7 +479,7 @@ def _get_flex_attention_masks(
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
@@ -505,7 +506,7 @@ def get_attention_masks(
         match self.model_args.attention_type:
             case "flex":
                 return self._get_flex_attention_masks(
-                    input_batch, tokenizer.eos_id, extra_inputs
+                    input_batch, tokenizer, extra_inputs
                 )
             case "varlen":
                 return self._get_varlen_attention_masks(
diff --git a/torchtitan/train.py b/torchtitan/train.py
index fd37dee8d0..17aaaac10d 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -455,12 +455,7 @@ def post_dataloading_process(
         extra_kwargs: dict[str, Any] = {}
 
         attn_type = getattr(self.model_args, "attention_type", False)
-        use_varlen_attn = attn_type == "varlen"
-        use_flex_attn = (
-            getattr(self.model_args, "use_flex_attn", False) or attn_type == "flex"
-        )
-
-        if use_flex_attn or use_varlen_attn:
+        if attn_type in ["flex", "varlen"]:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,

From ab033ddedcd503eee33b6de3dc0a5092ab101523 Mon Sep 17 00:00:00 2001
From: Angel Li
Date: Wed, 19 Nov 2025 09:47:20 -0800
Subject: [PATCH 10/11] remove use_flex for all other models

---
 .../distributed/activation_checkpoint.py      | 10 +++----
 torchtitan/experiments/forge/example_train.py |  2 +-
 .../experiments/gpt_oss/infra/parallelize.py  |  6 ++--
 torchtitan/experiments/gpt_oss/model/args.py  |  4 +--
 .../simple_fsdp/deepseek_v3/parallelize.py    |  6 ++--
 .../simple_fsdp/llama3/parallelize.py         |  3 +-
 .../experiments/vlm/infra/parallelize.py      |  7 +++--
 torchtitan/experiments/vlm/model/args.py      |  2 +-
 torchtitan/models/attention.py                |  8 +++---
 torchtitan/models/deepseek_v3/__init__.py     |  8 +++---
 .../models/deepseek_v3/infra/parallelize.py   | 26 ++++++-----------
 torchtitan/models/deepseek_v3/model/args.py   | 10 +++----
 torchtitan/models/deepseek_v3/model/model.py  | 28 ++++++++++---------
 torchtitan/models/llama3/__init__.py          | 10 ++++---
 torchtitan/models/llama3/infra/parallelize.py |  5 ++--
 torchtitan/models/llama3/model/args.py        |  6 ++--
 torchtitan/models/llama3/model/model.py       | 26 ++++++++---------
 torchtitan/models/llama4/__init__.py          |  6 ++--
 torchtitan/models/llama4/infra/parallelize.py |  6 ++--
 torchtitan/models/llama4/model/args.py        | 11 ++++----
 torchtitan/models/llama4/model/model.py       | 15 +++++-----
 torchtitan/models/qwen3/infra/parallelize.py  |  8 +++---
 torchtitan/models/qwen3/model/args.py         |  2 +-
 torchtitan/models/qwen3/model/model.py        | 24 ++++++++--------
 torchtitan/protocols/model.py                 | 18 ------------
 torchtitan/train.py                           |  2 +-
 26 files changed, 116 insertions(+), 143 deletions(-)

diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index a6936ac871..8359f71730 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -236,7 +236,7 @@ def _apply_ac_to_transformer_block(
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,7 +259,7 @@ def _apply_ac_to_transformer_block(
     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if attn_type == "flex":
+        if use_flex_attn:
             """
             For Flex Attention, we need to apply SAC carefully to avoid invalidating
             torch.compile. Any torch.compile inside the SAC region will be ignored,
@@ -288,7 +288,7 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    attn_type: str = "sdpa",
+    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
@@ -302,7 +302,7 @@ def apply_ac(
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        attn_type (str): Attention type (one of [sdpa, varlen, flex])
+        use_flex_attn (bool): Whether flex attention is enabled for the model.
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing.
     Returns:
@@ -326,7 +326,7 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            attn_type=attn_type,
+            use_flex_attn=use_flex_attn,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index 7b0b0c81e9..66ad151dd0 100644
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -161,7 +161,7 @@ def forward_backward_step(
         inputs = input_dict["input"]
 
         extra_kwargs = {}
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.model_args, "attn_type", "sdpa") == "flex":
             extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py
index 7714d497e4..1070f58aad 100644
--- a/torchtitan/experiments/gpt_oss/infra/parallelize.py
+++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -62,10 +62,6 @@ def parallelize_gptoss(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
@@ -111,6 +107,8 @@ def parallelize_gptoss(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py
index e78eac4d74..af4c51eadc 100644
--- a/torchtitan/experiments/gpt_oss/model/args.py
+++ b/torchtitan/experiments/gpt_oss/model/args.py
@@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs):
         n_kv_heads (int): Number of key-value heads.
         sliding_window_size (int): Size of the sliding attention window.
         attn_mask_type (str): Type of basic attention mask.
-        use_flex_attn (bool): Whether to use FlexAttention. Only supports True.
+        attn_type (str): Attention type, only supports Flex.
         original_seq_len (int): Original sequence length.
         rope_theta (float): Base for rotary positional encoding.
         rope_factor (float): Scaling factor for extended sequence lengths.
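
The lookup pattern that this commit threads through the trainers and parallelize helpers can be summarized with a small standalone sketch (the classes below are illustrative stand-ins, not the real ModelArgs types): model args that never define `attn_type` transparently fall back to SDPA.

```python
class LegacyArgs:
    """Older model args without the new field."""
    pass


class NewArgs:
    attn_type = "varlen"


for args in (LegacyArgs(), NewArgs()):
    attn_type = getattr(args, "attn_type", "sdpa")
    needs_masks = attn_type in ("flex", "varlen")
    print(attn_type, needs_masks)  # -> "sdpa False" then "varlen True"
```
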
@@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs):
     n_kv_heads: int = 8
     sliding_window_size: int = 128
     attn_mask_type: str = "causal"
-    use_flex_attn: bool = True  # NOTE: gpt-oss only support FlexAttention
+    attn_type: str = "flex"  # NOTE: gpt-oss only support FlexAttention
     # yarn
     original_seq_len: int = 4096
     rope_theta: float = 150000.0
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
index 6d415004cc..83e24d7dc1 100644
--- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -67,9 +67,9 @@ def parallelize_deepseekv3(
     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type != "sdpa"
     ):
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+        raise NotImplementedError("CP support is only supported for SDPA.")
 
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -85,13 +85,11 @@ def parallelize_deepseekv3(
                 "Currently, float8 tensorwise TP is not tested for deepseekv3"
             )
 
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index 67a012a3f7..fb07ef617a 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -102,7 +102,8 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)
 
     if job_config.activation_checkpoint.mode != "none":
-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
+        use_flex_attn = attn_type == "flex"
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index 6a97e4ece1..d418ad6edd 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -48,9 +48,9 @@ def parallelize_vlm(
         Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
""" - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: raise NotImplementedError("TP support for VLM training is still in progress.") @@ -58,6 +58,7 @@ def parallelize_vlm( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py index 11b6439ddd..49ba31246b 100644 --- a/torchtitan/experiments/vlm/model/args.py +++ b/torchtitan/experiments/vlm/model/args.py @@ -53,7 +53,7 @@ class Siglip2ModelArgs: spatial_merge_size: int = 1 layer_norm_eps: float = 1e-6 - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "causal" diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 11d3a35418..cc7b87cb20 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -60,7 +60,6 @@ def forward( xv: torch.Tensor, head_dim: torch.Tensor, attention_masks: VarlenMetadata, - is_causal: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: cu_seq_q = attention_masks.cu_seq_q cu_seq_k = attention_masks.cu_seq_k @@ -68,9 +67,9 @@ def forward( max_k = attention_masks.max_k n_local_heads = xq.shape[1] - xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim) - xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim) - xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim) + xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) return VarlenAttentionWrapper._compiled_varlen_attn( xq_packed, @@ -325,6 +324,7 @@ def create_varlen_metadata_for_document( max_seqlen = 0 if len(all_seq_lengths) > 0: all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward max_seqlen = all_seq_lengths.max().item() return VarlenMetadata( diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 525bd96c13..7e2d35a5d9 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -72,7 +72,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( @@ -124,7 +124,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( @@ -151,7 +151,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py 
b/torchtitan/models/deepseek_v3/infra/parallelize.py index 0793820ffd..a7e1ee0dc5 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -61,9 +61,10 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -84,7 +85,6 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -181,7 +181,6 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - use_flex_attn: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -211,18 +210,11 @@ def apply_non_moe_tp( PrepareModuleInput, ) - if use_flex_attn: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) - else: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48d4b5ece1..e683905878 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -44,7 +44,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. - use_flex_attn (bool): Whether to use FlexAttention. + attn_type (str): Attention type. attn_mask_type (str): Type of attention mask. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. @@ -76,7 +76,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # yarn @@ -101,10 +101,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." 
-            )
+        if job_config.parallelism.context_parallel_degree > 1 and self.attn_type != "sdpa":
+            raise NotImplementedError("CP support is only supported for SDPA.")
 
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index 3cf56eb1b2..7d7635a4ad 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -184,11 +184,12 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale
 
-        self.use_flex_attn = model_args.use_flex_attn
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attn_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def forward(
         self,
@@ -245,14 +246,15 @@ def forward(
         k = k.transpose(1, 2)  # (bsz, n_heads, seqlen, qk_head_dim)
         v = v.transpose(1, 2)  # (bsz, n_heads, seqlen, v_head_dim)
 
-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask)
-            output = self.inner_attention(
-                q, k, v, block_mask=attention_masks, scale=self.softmax_scale
-            )
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(q, k, v, scale=self.softmax_scale)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask)
+                output = self.inner_attention(
+                    q, k, v, block_mask=attention_masks, scale=self.softmax_scale
+                )
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(q, k, v, scale=self.softmax_scale)
 
         # Reshape and project output
         output = output.transpose(
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index aab2f7fb61..75ab234ebc 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -36,7 +36,7 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        attention_type="flex",
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "debugmodel_varlen_attn": TransformerModelArgs(
@@ -45,7 +45,8 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        attention_type="varlen",
+        attn_type="varlen",
+        attn_mask_type="block_causal",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -64,7 +65,7 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        attention_type="flex",
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "8B_varlen": TransformerModelArgs(
@@ -75,7 +76,8 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        attention_type="varlen",
+        attn_type="varlen",
+        attn_mask_type="block_causal",
     ),
     "70B": TransformerModelArgs(
         dim=8192,
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index e99c62633d..b517e5c15f 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -91,13 +91,14 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
-    attn_type = getattr(model.model_args, "attention_type", False)
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            attn_type=attn_type,
+            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index 7f1aa41e71..81680074eb 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -9,8 +9,6 @@
 
 from dataclasses import dataclass, field
 
-from typing import Literal
-
 from torch import nn
 from torchtitan.config import JobConfig
 from torchtitan.models.utils import get_dense_model_nparams_and_flops
@@ -44,7 +42,7 @@ class TransformerModelArgs(BaseModelArgs):
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
 
-    attention_type: Literal["flex", "varlen"] = "sdpa"
+    attn_type: str = "sdpa"
     attn_mask_type: str = "causal"
     eos_id: int = 0
 
@@ -58,7 +56,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
 
         if (
             job_config.parallelism.context_parallel_degree > 1
-            and self.attention_type != "sdpa"
+            and self.attn_type != "sdpa"
         ):
             raise NotImplementedError(
                 "CP support for FlexAttention is still in progress."
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 4b861d6181..74b862bf76 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -194,7 +194,7 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-        self.attn_type = model_args.attention_type
+        self.attn_type = model_args.attn_type
         match self.attn_type:
             case "flex":
                 self.inner_attention = FlexAttentionWrapper()
@@ -258,11 +258,12 @@ def forward(
                     xv,
                     self.head_dim,
                     attention_masks,
-                    is_causal=True,
                 )
-            case _:
+            case "sdpa":
                 assert attention_masks is None
                 output = self.inner_attention(xq, xk, xv)
+            case _:
+                raise ValueError(f"Unknown attention type: {self.attn_type}")
 
         output = output.transpose(
             1, 2
@@ -489,28 +490,25 @@ def _get_flex_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
-    def _get_varlen_attention_masks(
-        self,
-        input_batch: torch.Tensor,
-        eos_id: int,
-        extra_inputs: dict[str, torch.Tensor] | None = None,
-    ) -> AttentionMasksType:
-        return create_varlen_metadata_for_document(input_batch, eos_id)
-
     def get_attention_masks(
         self,
         input_batch: torch.Tensor,
         tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
-        match self.model_args.attention_type:
+        match self.model_args.attn_type:
             case "flex":
                 return self._get_flex_attention_masks(
                     input_batch, tokenizer, extra_inputs
                 )
             case "varlen":
-                return self._get_varlen_attention_masks(
-                    input_batch, tokenizer.eos_id, extra_inputs
+                if self.model_args.attn_mask_type != "block_causal":
+                    raise ValueError(
+                        f"varlen attention is only supported with block_causal \
+                        attention mask type, got {self.model_args.attn_mask_type}"
+                    )
+                return create_varlen_metadata_for_document(
+                    input_batch, tokenizer.eos_id
                 )
             case _:
                 raise NotImplementedError(
diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py
index 24196c2326..b8bd9a4484 100644
--- a/torchtitan/models/llama4/__init__.py
+++ b/torchtitan/models/llama4/__init__.py
@@ -67,7 +67,7 @@
         rope_scaling_args=RoPEScalingArgs(),
         every_n_layers_nope=4,
         fixed_attn_block_size=256,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "17bx16e_irope": TransformerModelArgs(
@@ -83,7 +83,7 @@
         moe_args=MoEArgs(num_experts=16),
         interleave_moe_layer_step=1,
         every_n_layers_nope=4,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
     "17bx128e_irope": TransformerModelArgs(
@@ -96,7 +96,7 @@
         rope_theta=500000,
         moe_args=MoEArgs(num_experts=128),
         every_n_layers_nope=4,
-        use_flex_attn=True,
+        attn_type="flex",
         attn_mask_type="block_causal",
     ),
 }
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 9911ecdfd0..01ce9d543b 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -75,10 +75,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -117,6 +113,8 @@ def parallelize_llama(
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py
index 7fcc9871f5..a277ca382e 100644
--- a/torchtitan/models/llama4/model/args.py
+++ b/torchtitan/models/llama4/model/args.py
@@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs):
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
 
-    use_flex_attn: bool = False
+    attn_type: str = "sdpa"
     attn_mask_type: str = "causal"
     # iRoPE settings
     # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
@@ -76,10 +76,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
             self.moe_args.use_grouped_mm = False
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
+        if (
+            job_config.parallelism.context_parallel_degree > 1
+            and self.attn_type != "sdpa"
+        ):
+            raise NotImplementedError("CP support is only supported for SDPA.")
 
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py
index c8241b84de..6b9d2d2d9e 100644
--- a/torchtitan/models/llama4/model/model.py
+++ b/torchtitan/models/llama4/model/model.py
@@ -202,11 +202,12 @@ def __init__(
         # values of these two variables.
         self.use_rope = use_rope
 
-        self.use_flex_attn = model_args.use_flex_attn
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attn_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -217,7 +218,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
-        attention_masks: AttentionMasksType | None,
+        attention_masks: AttentionMasksType,
     ):
         """
         Forward pass of the attention module.
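
The hunk that follows selects a per-layer mask out of a dict keyed by "rope"/"nope" for iRoPE-style models. A toy standalone illustration of that lookup (plain strings stand in for the real BlockMask objects, not part of the patch):

```python
# stand-in values; in the real code these are BlockMask instances built per batch
attention_masks = {"rope": "block-causal mask", "nope": "full-document mask"}


def pick_mask(use_rope: bool) -> str:
    # RoPE layers and NoPE layers each get their own precomputed mask
    return attention_masks["rope" if use_rope else "nope"]


print(pick_mask(True))   # mask used by RoPE layers
print(pick_mask(False))  # mask used by NoPE layers
```
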
@@ -252,7 +253,7 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        if self.use_flex_attn:
+        if self.attn_type == "flex":
             assert isinstance(attention_masks, dict), attention_masks
             attention_mask = attention_masks["rope" if self.use_rope else "nope"]
             output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 6b8dc3d5a6..74254081b6 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -59,9 +59,9 @@ def parallelize_qwen3(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
+        raise NotImplementedError("CP support is only supported for SDPA.")
 
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
@@ -112,7 +112,7 @@ def parallelize_qwen3(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            use_flex_attn=attn_type == "flex",
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index 0c700ce2e0..2def3a949a 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -36,7 +36,7 @@ class Qwen3ModelArgs(BaseModelArgs):
     max_seq_len: int = 4096
     depth_init: bool = True
 
-    use_flex_attn: bool = False
+    attn_type: str = "sdpa"
     attn_mask_type: str = "causal"
     eos_id: int = 151645
 
diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py
index a4f0a59844..89296ed98d 100644
--- a/torchtitan/models/qwen3/model/model.py
+++ b/torchtitan/models/qwen3/model/model.py
@@ -143,7 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.head_dim
         self.scaling = self.head_dim**-0.5
-        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)
+        self.attn_type = getattr(model_args, "attn_type", "sdpa")
 
         # RMSNorm added here to the here to include the q-k norm
         # This is one of the main differences between Llama3 and Qwen3
@@ -167,10 +167,11 @@ def __init__(self, model_args: Qwen3ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -226,12 +227,13 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index c8c0bad2c7..4cb193c31a 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -71,21 +71,3 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
-
-    def _get_varlen_attention_masks(
-        self,
-        input_batch: torch.Tensor,
-        eos_id: int,
-        extra_inputs: dict[str, torch.Tensor] | None = None,
-    ) -> AttentionMasksType:
-        raise NotImplementedError(
-            "This model does not support variable length attention."
-        )
-
-    def _get_flex_attention_masks(
-        self,
-        input_batch: torch.Tensor,
-        eos_id: int,
-        extra_inputs: dict[str, torch.Tensor] | None = None,
-    ) -> AttentionMasksType:
-        raise NotImplementedError("This model does not support flex attention.")
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 17aaaac10d..4d2f13d6da 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -454,7 +454,7 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}
 
-        attn_type = getattr(self.model_args, "attention_type", False)
+        attn_type = getattr(self.model_args, "attn_type", "sdpa")
         if attn_type in ["flex", "varlen"]:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,

From 2b1a40f8089b2f59464485401222349a733ee548 Mon Sep 17 00:00:00 2001
From: Angel Li
Date: Fri, 21 Nov 2025 08:26:23 -0800
Subject: [PATCH 11/11] integration test

---
 tests/integration_tests/features.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py
index 8bf3a0249f..6fafa29871 100755
--- a/tests/integration_tests/features.py
+++ b/tests/integration_tests/features.py
@@ -346,6 +346,18 @@ def build_features_test_list() -> list[OverrideDefinitions]:
             "fsdp+flex_attn+per_op_sac",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_shard_degree=4",
+                    "--activation_checkpoint.mode='full'",
+                    "--model.flavor=debugmodel_varlen_attn",
+                ]
+            ],
+            "FSDP+VARLEN_ATTN",
+            "fsdp+varlen_attn",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [