Commit af8e3e8: correct loss/perf
1 parent 93a5bac

6 files changed, +24 -23 lines
torchtitan/hf_datasets/text_datasets.py
Lines changed: 3 additions & 15 deletions

@@ -79,7 +79,6 @@ def varlen_collate_fn(batch):
         Packed (input_dict, label) with collapsed batch dimension
     """
     if len(batch) == 1:
-        # Single sample - already packed
         input_dict, label = batch[0]
         return {
             "input": input_dict["input"].unsqueeze(0),  # [1, seq_len]
@@ -89,7 +88,6 @@ def varlen_collate_fn(batch):
             "max_k": input_dict["max_k"],
         }, label.unsqueeze(0)  # [1, seq_len]

-    # Multiple samples - pack them together
     inputs = []
     labels = []
     cu_seqlens_list = []
@@ -100,23 +98,17 @@ def varlen_collate_fn(batch):
         inputs.append(input_dict["input"])
         labels.append(label)

-        # Get cu_seqlens from this sample and adjust by offset
         cu_seqlens = input_dict["cu_seq_q"]
-        # Don't include the last boundary (we'll add it at the end)
         cu_seqlens_adjusted = cu_seqlens[:-1] + offset
         cu_seqlens_list.append(cu_seqlens_adjusted)

-        # Track maximum sequence length across all samples
         max_seqlen = max(max_seqlen, input_dict["max_q"])

-        # Update offset for next sample
         offset += len(input_dict["input"])

-    # Concatenate all inputs and labels
-    packed_input = torch.cat(inputs, dim=0).unsqueeze(0)  # Shape: [total_tokens]
-    packed_label = torch.cat(labels, dim=0).unsqueeze(0)  # Shape: [total_tokens]
+    packed_input = torch.cat(inputs, dim=0).unsqueeze(0)  # shape: [1, total_tokens]
+    packed_label = torch.cat(labels, dim=0).unsqueeze(0)  # shape: [1, total_tokens]

-    # Combine all cu_seqlens and add final boundary
     packed_cu_seqlens = torch.cat(
         cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32)]
     )
@@ -189,7 +181,6 @@ def __iter__(self):

                 # marks where this current document ends
                 if self.use_varlen_attn:
-                # if self.use_varlen_attn or self.use_flex_attn:
                     self._boundary_buffer.append(len(self._token_buffer))

                 while len(self._token_buffer) >= max_buffer_token_len:
@@ -198,19 +189,16 @@ def __iter__(self):
                     # update tokens to the remaining tokens
                     self._token_buffer = self._token_buffer[max_buffer_token_len:]

-                    input = x[:-1]  # print device here
+                    input = x[:-1]
                     label = x[1:]

                     if self.use_varlen_attn:
-                    # if self.use_varlen_attn or self.use_flex_attn:
                         boundaries_in_window = [
                             b for b in self._boundary_buffer
                             if b <= max_buffer_token_len
                         ]

                         cu_seqlens = torch.tensor(boundaries_in_window, dtype=torch.int32)
-                        # print device here
-

                         self._boundary_buffer = [
                             b - max_buffer_token_len
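
For reference, the collate path above packs several variable-length samples into a single [1, total_tokens] row and records document boundaries as cumulative sequence lengths (cu_seqlens). Below is a minimal standalone sketch of that packing idea; the sample tensors are made up and this is not the exact torchtitan function.

import torch

# Hypothetical token tensors of different lengths, one per document.
samples = [torch.arange(5), torch.arange(3), torch.arange(4)]

inputs, boundaries, offset = [], [0], 0
for tokens in samples:
    inputs.append(tokens)
    offset += len(tokens)
    boundaries.append(offset)  # cumulative end index of each document

packed = torch.cat(inputs).unsqueeze(0)                    # shape [1, 12]
cu_seqlens = torch.tensor(boundaries, dtype=torch.int32)   # tensor([0, 5, 8, 12])
max_seqlen = max(len(t) for t in samples)                  # 5

# cu_seqlens and max_seqlen are what variable-length attention kernels expect
# in order to keep attention from crossing document boundaries inside the
# single packed row.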

torchtitan/models/llama3/__init__.py
Lines changed: 3 additions & 3 deletions

@@ -55,9 +55,9 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        # use_flex_attn=True,
-        # attn_mask_type="block_causal",
-        use_varlen_attn=True,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+        # use_varlen_attn=True,
     ),
     "70B": TransformerModelArgs(
         dim=8192,
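
This switches the 8B flavor from the varlen-attention path to FlexAttention with a block-causal mask. As an illustration of what "block_causal" means for a packed row, the snippet below builds a dense visualization of the masking rule; it is not torchtitan's mask construction, and the document layout is hypothetical.

import torch

# Hypothetical packed row: token i belongs to document docs[i].
docs = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
q_idx = torch.arange(len(docs)).unsqueeze(1)   # query positions (rows)
kv_idx = torch.arange(len(docs)).unsqueeze(0)  # key positions (columns)

# Block-causal rule: causal within the row, and attention never crosses
# a document boundary.
allowed = (q_idx >= kv_idx) & (docs[q_idx] == docs[kv_idx])
print(allowed.int())  # block-diagonal, lower-triangular pattern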

torchtitan/models/llama3/model/model.py
Lines changed: 14 additions & 1 deletion

@@ -134,8 +134,10 @@ def apply_rotary_emb(
     Returns:
         tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
+
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
@@ -282,7 +284,18 @@ def forward(
             cu_seq_q = kwargs.get("cu_seq_q_list")
             assert(cu_seq_q is not None)
             assert(type(cu_seq_q) is list)
-            xq, xk = self._apply_rotary_per_sequence(xq, xk, freqs_cis, cu_seq_q)
+
+            true_seq_len = freqs_cis.shape[0]
+            total_tokens = xq.shape[1]
+
+            true_bs = total_tokens // true_seq_len
+            xq = xq.view(true_bs, true_seq_len, -1, self.head_dim)
+            xk = xk.view(true_bs, true_seq_len, -1, self.head_dim)
+
+            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+
+            xq = xq.view(1, total_tokens, -1, self.head_dim)
+            xk = xk.view(1, total_tokens, -1, self.head_dim)
         else:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
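
The new branch applies rotary embeddings to the packed [1, total_tokens] activations by temporarily viewing them as [batch, seq_len, heads, head_dim], so each fixed-length window reuses the same freqs_cis. A self-contained sketch of that reshape round-trip follows; the dimensions and the inline rotation are illustrative, not the project's apply_rotary_emb.

import torch

n_heads, head_dim = 2, 8
true_seq_len, true_bs = 4, 3
total_tokens = true_bs * true_seq_len

# Packed activations: one row holding true_bs windows of true_seq_len tokens.
xq = torch.randn(1, total_tokens, n_heads, head_dim)

# Complex rotations precomputed for a single window of true_seq_len positions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(true_seq_len).float(), inv_freq)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # [true_seq_len, head_dim // 2]

# View the packed row as [bs, seq_len, ...] so positions 0..seq_len-1 line up
# with freqs_cis in every window, rotate, then flatten back to [1, total_tokens, ...].
xq_b = xq.view(true_bs, true_seq_len, n_heads, head_dim)
xq_c = torch.view_as_complex(xq_b.float().reshape(true_bs, true_seq_len, n_heads, -1, 2))
xq_rot = torch.view_as_real(xq_c * freqs_cis.view(1, true_seq_len, 1, -1)).flatten(3)
xq_out = xq_rot.view(1, total_tokens, n_heads, head_dim)

Note that the view() round-trip relies on total_tokens being an exact multiple of freqs_cis.shape[0], i.e. fixed-length windows, and that rotary positions then restart at 0 for each window rather than for each packed document.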

torchtitan/models/llama3/train_configs/debug_model.toml
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ enable_wandb = false

 [model]
 name = "llama3"
-flavor = "debugmodel_varlen_attn"
+flavor = "debugmodel_flex_attn"
 # flavor = "debugmodel_flex_attn"
 # test folder with tokenizer.json, for debug purpose only
 hf_assets_path = "./tests/assets/tokenizer"

torchtitan/models/llama3/train_configs/llama3_8b.toml
Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@ description = "Llama 3 8B training"

 [profiling]
 enable_profiling = true
-save_traces_folder = "profile_trace"
+save_traces_folder = "flex_profile_trace"
 profile_freq = 100

 [metrics]
@@ -32,7 +32,7 @@ warmup_steps = 200 # lr scheduler warm up
 local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
-steps = 100
+steps = 1000
 dataset = "c4"

 [parallelism]

torchtitan/tools/profiling.py
Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def trace_handler(prof):
         schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
         on_trace_ready=trace_handler,
         record_shapes=True,
-        # with_stack=True,  # python stack
+        with_stack=True,  # python stack
     ) as torch_profiler:
         torch_profiler.step_num = global_step
         yield torch_profiler
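
Enabling with_stack=True makes the profiler record Python call stacks so operator time can be attributed to source lines, at the cost of some extra overhead. A minimal, self-contained sketch of an equivalent profiler setup is below; the schedule values, workload, and trace path are illustrative.

import torch
from torch.profiler import ProfilerActivity, profile, schedule

def trace_handler(prof):
    # Export a Chrome trace at the end of each active window (path is illustrative).
    prof.export_chrome_trace("trace.json")

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
    record_shapes=True,
    with_stack=True,  # record Python stacks for per-line attribution
) as prof:
    for _ in range(8):
        torch.mm(torch.randn(64, 64), torch.randn(64, 64))
        prof.step()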
