Commit 55352a5

collapse batch outside of dataloader
1 parent cad97e5

File tree

7 files changed: +77 -159 lines changed


torchtitan/hf_datasets/text_datasets.py

Lines changed: 2 additions & 126 deletions
@@ -18,7 +18,6 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
 from torchtitan.hf_datasets import DatasetConfig
-from torchtitan.protocols import train_spec
 from torchtitan.tools.logging import logger


@@ -68,63 +67,6 @@ def _validate_dataset(
     return path, config.loader, config.sample_processor


-def varlen_collate_fn(batch):
-    """
-    Custom collate function for variable length attention
-    Collapses batch dimension by packing all samples into one sequence
-
-    Args:
-        batch: List of (input_dict, label) tuples
-
-    Returns:
-        packed (input_dict, label) with collapsed batch dimension
-    """
-    if len(batch) == 1:
-        input_dict, label = batch[0]
-        return {
-            "input": input_dict["input"].unsqueeze(0),  # [1, seq_len]
-            "cu_seq_q": input_dict["cu_seq_q"],
-            "cu_seq_k": input_dict["cu_seq_k"],
-            "max_q": input_dict["max_q"],
-            "max_k": input_dict["max_k"],
-        }, label.unsqueeze(
-            0
-        )  # [1, seq_len]
-
-    inputs = []
-    labels = []
-    cu_seqlens_list = []
-    offset = 0
-    max_seqlen = 0
-
-    for input_dict, label in batch:
-        inputs.append(input_dict["input"])
-        labels.append(label)
-
-        cu_seqlens = input_dict["cu_seq_q"]
-        cu_seqlens_adjusted = cu_seqlens[:-1] + offset
-        cu_seqlens_list.append(cu_seqlens_adjusted)
-
-        max_seqlen = max(max_seqlen, input_dict["max_q"])
-
-        offset += len(input_dict["input"])
-
-    packed_input = torch.cat(inputs, dim=0).unsqueeze(0)  # shape: [1, total_tokens]
-    packed_label = torch.cat(labels, dim=0).unsqueeze(0)  # shape: [1, total_tokens]
-
-    packed_cu_seqlens = torch.cat(
-        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32)]
-    )
-
-    return {
-        "input": packed_input,
-        "cu_seq_q": packed_cu_seqlens,
-        "cu_seq_k": packed_cu_seqlens,
-        "max_q": max_seqlen,
-        "max_k": max_seqlen,
-    }, packed_label
-
-
 class HuggingFaceTextDataset(IterableDataset, Stateful):
     def __init__(
         self,
@@ -155,9 +97,6 @@ def __init__(
         self._sample_idx = 0
         self._token_buffer: list[int] = []

-        self._boundary_buffer: list[int] = [0]
-        self.use_varlen_attn: bool = False
-
     def _get_data_iter(self):
         # For map-style datasets, resume by skipping to the correct index
         # For iterable-style datasets, the underlying iterator already points to the correct index
@@ -182,63 +121,13 @@ def __iter__(self):
                 self._token_buffer.extend(sample_tokens)
                 self._sample_idx += 1

-                if self.use_varlen_attn:
-                    self._boundary_buffer.append(len(self._token_buffer))
-
                 while len(self._token_buffer) >= max_buffer_token_len:
                     x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
-
                     # update tokens to the remaining tokens
                     self._token_buffer = self._token_buffer[max_buffer_token_len:]
-
                     input = x[:-1]
                     label = x[1:]
-
-                    if self.use_varlen_attn:
-                        boundaries_in_window = [
-                            b
-                            for b in self._boundary_buffer
-                            if b <= max_buffer_token_len
-                        ]
-
-                        cu_seqlens = torch.tensor(
-                            boundaries_in_window, dtype=torch.int32
-                        )
-
-                        self._boundary_buffer = [
-                            b - max_buffer_token_len
-                            for b in self._boundary_buffer
-                            if b > max_buffer_token_len
-                        ]
-
-                        if not self._boundary_buffer or self._boundary_buffer[0] != 0:
-                            self._boundary_buffer.insert(0, 0)
-
-                        cu_seqlens_input = cu_seqlens[cu_seqlens <= len(input)]
-                        if cu_seqlens_input[-1] != len(input):
-                            cu_seqlens_input = torch.cat(
-                                [
-                                    cu_seqlens_input,
-                                    torch.tensor([len(input)], dtype=torch.int32),
-                                ]
-                            )
-
-                        seq_lengths = torch.diff(cu_seqlens_input)
-                        max_seqlen = (
-                            seq_lengths.max().item()
-                            if len(seq_lengths) > 0
-                            else self.seq_len
-                        )
-
-                        yield {
-                            "input": input,
-                            "cu_seq_q": cu_seqlens_input,
-                            "cu_seq_k": cu_seqlens_input,
-                            "max_q": max_seqlen,
-                            "max_k": max_seqlen,
-                        }, label
-                    else:
-                        yield {"input": input}, label
+                    yield {"input": input}, label

             if not self.infinite:
                 logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -256,7 +145,6 @@ def __iter__(self):

     def load_state_dict(self, state_dict):
         self._token_buffer = state_dict["token_buffer"]
-        self._boundary_buffer = state_dict.get("boundary_buffer", [0])

         if isinstance(self._data, Dataset):
             self._sample_idx = state_dict["sample_idx"]
@@ -265,10 +153,7 @@ def load_state_dict(self, state_dict):
             self._data.load_state_dict(state_dict["data"])

     def state_dict(self):
-        _state_dict = {
-            "token_buffer": self._token_buffer,
-            "boundary_buffer": self._boundary_buffer,
-        }
+        _state_dict = {"token_buffer": self._token_buffer}

         if isinstance(self._data, Dataset):
             _state_dict["sample_idx"] = self._sample_idx
@@ -293,11 +178,6 @@ def build_text_dataloader(
     batch_size = job_config.training.local_batch_size
     seq_len = job_config.training.seq_len

-    model_args = train_spec.get_train_spec(job_config.model.name).model_args[
-        job_config.model.flavor
-    ]
-    use_varlen_attn = getattr(model_args, "use_varlen_attn", False)
-
     hf_ds = HuggingFaceTextDataset(
         dataset_name=dataset_name,
         dataset_path=dataset_path,
@@ -307,16 +187,12 @@
         dp_world_size=dp_world_size,
         infinite=infinite,
     )
-    hf_ds.use_varlen_attn = use_varlen_attn
-
-    collate_fn = varlen_collate_fn if use_varlen_attn else None

     return ParallelAwareDataloader(
         dataset=hf_ds,
         dp_rank=dp_rank,
         dp_world_size=dp_world_size,
         batch_size=batch_size,
-        collate_fn=collate_fn,
     )
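
Net effect of this file: the dataset and dataloader no longer know anything about variable length attention. varlen_collate_fn, the _boundary_buffer bookkeeping, and the use_varlen_attn flag are removed, so every sample is a plain {"input": tokens}, label pair and the default collation stacks them into [batch, seq_len] tensors. A minimal sketch of what a consumer now sees (the loop and variable names are illustrative, not part of this commit):

# Illustrative only: iterate the dataloader returned by build_text_dataloader.
for input_dict, labels in dataloader:
    tokens = input_dict["input"]          # LongTensor of shape [local_batch_size, seq_len]
    assert labels.shape == tokens.shape   # next-token labels, same shape
    # No cu_seq_q / cu_seq_k / max_q / max_k keys here anymore; varlen metadata
    # is now derived from EOS positions inside the model (see attention.py below).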

torchtitan/models/attention.py

Lines changed: 55 additions & 1 deletion
@@ -20,7 +20,7 @@
     flex_attention,
 )

-from torch.nn.attention.varlen import varlen_attn
+from torch.nn.attention.varlen import varlen_attn, VarlenMetadata


 __all__ = [
@@ -251,3 +251,57 @@ def create_attention_mask(*args, **kwargs):
     arguments.
     """
     return _compiled_create_block_mask(*args, **kwargs)
+
+
+def create_varlen_cu_seqs(input_batch: torch.Tensor, eos_id: int) -> VarlenMetadata:
+    """
+    Creates cumulative sequence length indices needed for variable length attention
+
+    Args:
+        input_batch
+        eos_id: the EOS id marker
+
+    Returns:
+        VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len
+    """
+    batch_size, seq_len = input_batch.shape
+    device = input_batch.device
+    cu_seqlens_list, all_seq_lengths = [], []
+    offset = 0
+    max_seqlen = 0
+
+    for b in range(batch_size):
+        tokens = input_batch[b]
+        eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)
+        sample_cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], dtype=torch.int32, device=device),
+                eos_positions + 1,
+                torch.tensor([seq_len], dtype=torch.int32, device=device),
+            ]
+        )
+        sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)
+
+        seq_lengths = torch.diff(sample_cu_seqlens)
+        all_seq_lengths.append(seq_lengths)
+
+        cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
+        cu_seqlens_list.append(cu_seqlens_adjusted)
+
+        offset += seq_len
+
+    packed_cu_seqlens = torch.cat(
+        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]
+    )
+
+    max_seqlen = 0
+    if len(all_seq_lengths) > 0:
+        all_seq_lengths = torch.cat(all_seq_lengths)
+        max_seqlen = all_seq_lengths.max().item()
+
+    return VarlenMetadata(
+        cu_seq_q=packed_cu_seqlens,
+        cu_seq_k=packed_cu_seqlens,
+        max_q=max_seqlen,
+        max_k=max_seqlen,
+    )
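
A small worked example may make the new helper concrete. The token values and eos_id=0 below are made up; the expected outputs follow directly from the code added above:

import torch

from torchtitan.models.attention import create_varlen_cu_seqs

# Two rows of EOS-delimited documents, packed to seq_len=8 (illustrative data).
batch = torch.tensor(
    [
        [5, 7, 0, 9, 4, 0, 2, 3],  # docs of length 3 and 3, plus a 2-token tail
        [1, 0, 6, 6, 6, 6, 6, 0],  # docs of length 2 and 6
    ]
)
meta = create_varlen_cu_seqs(batch, eos_id=0)
# Both rows are flattened into one packed 16-token stream:
#   meta.cu_seq_q -> tensor([0, 3, 6, 8, 10, 16], dtype=torch.int32)
#   meta.cu_seq_k is the same tensor (self-attention); meta.max_q == meta.max_k == 6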

torchtitan/models/llama3/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@
         vocab_size=2048,
         rope_theta=500000,
         use_varlen_attn=True,
+        attn_mask_type="varlen_attn",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -76,6 +77,7 @@
         multiple_of=1024,
         rope_theta=500000,
         use_varlen_attn=True,
+        attn_mask_type="varlen_attn",
     ),
     "70B": TransformerModelArgs(
         dim=8192,

torchtitan/models/llama3/model/model.py

Lines changed: 12 additions & 23 deletions
@@ -13,11 +13,12 @@
 from torch import nn
 from torch.nn.attention.flex_attention import and_masks, BlockMask

-from torch.nn.attention.varlen import varlen_attn
+from torch.nn.attention.varlen import varlen_attn, VarlenMetadata

 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
+    create_varlen_cu_seqs,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
@@ -227,6 +228,7 @@ def forward(
         """

         bs, seqlen, _ = x.shape
+
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
@@ -236,20 +238,7 @@
         xk = xk.view(bs, seqlen, -1, self.head_dim)
         xv = xv.view(bs, seqlen, -1, self.head_dim)

-        if self.use_varlen_attn:
-            true_seq_len = freqs_cis.shape[0]
-            total_tokens = xq.shape[1]
-
-            true_bs = total_tokens // true_seq_len
-            xq = xq.view(true_bs, true_seq_len, -1, self.head_dim)
-            xk = xk.view(true_bs, true_seq_len, -1, self.head_dim)
-
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
-
-            xq = xq.view(1, total_tokens, -1, self.head_dim)
-            xk = xk.view(1, total_tokens, -1, self.head_dim)
-        else:
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
@@ -259,18 +248,16 @@
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        assert (
-            isinstance(attention_masks, BlockMask) or attention_masks is None
-        ), attention_masks
-
         if self.use_flex_attn:
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            cu_seq_q = kwargs.get("cu_seq_q")
-            cu_seq_k = kwargs.get("cu_seq_k")
-            max_q = kwargs.get("max_q")
-            max_k = kwargs.get("max_k")
+            assert isinstance(attention_masks, VarlenMetadata), attention_masks
+
+            cu_seq_q = attention_masks.cu_seq_q
+            cu_seq_k = attention_masks.cu_seq_k
+            max_q = attention_masks.max_q
+            max_k = attention_masks.max_k

             n_local_heads = xq.shape[1]
             xq_packed = (
@@ -515,6 +502,8 @@ def get_attention_masks(
             case "block_causal":
                 B = input_batch.shape[0]
                 mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+            case "varlen_attn":
+                return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
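
Taken together with the dataloader change, the varlen path no longer threads cu_seq_q / cu_seq_k / max_q / max_k through collate_fn and kwargs: get_attention_masks builds a VarlenMetadata directly from the packed input batch, and the attention layer reads the fields off that object. A rough sketch of the intended call flow (the trainer-side variable names and keyword usage here are assumptions, not code from this commit):

# Hypothetical training-step fragment, assuming attn_mask_type == "varlen_attn".
tokens = input_dict["input"].to(device)        # [bs, seq_len], EOS-delimited documents
attention_masks = model.get_attention_masks(
    input_batch=tokens,
    tokenizer=tokenizer,
)                                              # returns a VarlenMetadata for this batch
logits = model(tokens, attention_masks=attention_masks)
# Inside the attention forward, the varlen branch unpacks cu_seq_q/cu_seq_k/max_q/max_k
# from attention_masks and calls varlen_attn on the packed q/k/v tensors.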

torchtitan/models/llama3/train_configs/llama3_8b_varlen.toml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ save_tb_folder = "tb"

 [model]
 name = "llama3"
-flavor = "8B"
+flavor = "8B_varlen"
 hf_assets_path = "./assets/hf/Llama-3.1-8B"
 # converters = ["float8"]

torchtitan/protocols/model.py

Lines changed: 2 additions & 1 deletion
@@ -12,13 +12,14 @@
 import torch.nn as nn

 from torch.nn.attention.flex_attention import BlockMask
+from torch.nn.attention.varlen import VarlenMetadata

 from torchtitan.components.tokenizer import BaseTokenizer

 from torchtitan.config import JobConfig


-AttentionMasksType = dict[str, BlockMask] | BlockMask
+AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata


 @dataclass
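
Since AttentionMasksType is now a three-way union, code that consumes it has to branch on the concrete type. A minimal, illustrative dispatch (the helper name is made up and not part of this commit):

from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.varlen import VarlenMetadata


def describe_masks(masks) -> str:
    # Hypothetical helper showing the three AttentionMasksType variants.
    if isinstance(masks, VarlenMetadata):
        return "varlen packing metadata (cu_seq_q / cu_seq_k / max_q / max_k)"
    if isinstance(masks, BlockMask):
        return "a single flex-attention BlockMask"
    return f"a dict of {len(masks)} named BlockMasks"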
