Skip to content

Commit accfa1f

Browse files
committed
Fix bugs in pipeline parallelism (PP)
1 parent 91838de commit accfa1f

File tree

4 files changed

+29
-21
lines changed

4 files changed

+29
-21
lines changed

torchtitan/experiments/train_llama_hf/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
#
77
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
88

9-
from loss import cross_entropy_loss_hf
109
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
1110

1211
from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
1312
from torchtitan.experiments.train_llama_hf.dataset import (
1413
build_pos_included_hf_dataloader,
1514
)
15+
from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf
1616
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
1717

1818
from .parallelize_llama import parallelize_llama

torchtitan/experiments/train_llama_hf/model/parallelize_llama.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ def apply_tp(
167167
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
168168
# by folding (and unfolding) the batch dimension and the sequence dimension.
169169
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
170-
for transformer_block in model.model.layers:
170+
if isinstance(model.model.layers, nn.ModuleDict):
171+
transformer_blocks = model.model.layers.values()
172+
else:
173+
transformer_blocks = model.model.layers
174+
for transformer_block in transformer_blocks:
171175
layer_plan = {
172176
"input_layernorm": SequenceParallel(),
173177
"self_attn": prepare_module_input(
@@ -260,8 +264,12 @@ def apply_fsdp(
260264
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
261265
if cpu_offload:
262266
fsdp_config["offload_policy"] = CPUOffloadPolicy()
267+
if isinstance(model.model.layers, nn.ModuleDict):
268+
layer_items = [(int(k), v) for (k, v) in model.model.layers.items()]
269+
else:
270+
layer_items = list(enumerate(model.model.layers))
263271

264-
for layer_id, transformer_block in enumerate(model.model.layers):
272+
for layer_id, transformer_block in layer_items:
265273
if reshard_after_forward_policy == "always":
266274
reshard_after_forward = True
267275
elif reshard_after_forward_policy == "never":
@@ -274,7 +282,7 @@ def apply_fsdp(
274282
else:
275283
# As an optimization, do not reshard after forward for the last
276284
# transformer block since FSDP would prefetch it immediately
277-
reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
285+
reshard_after_forward = layer_id < len(layer_items) - 1
278286
else:
279287
raise ValueError(
280288
f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."

torchtitan/experiments/train_llama_hf/model/pipeline_llama.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from torchtitan.tools.logging import logger
3131

32-
3332
DeviceType = Union[int, str, torch.device]
3433

3534

@@ -87,8 +86,10 @@ def forward(
8786
# create position embeddings to be shared across the decoder layers
8887
position_embeddings = self.rotary_emb(hidden_states, position_ids)
8988

90-
# decoder layers
91-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
89+
# decoder layers, ok since ModuleDict is ordered
90+
for decoder_layer in list(self.layers.values())[
91+
: self.config.num_hidden_layers
92+
]:
9293

9394
if self.gradient_checkpointing and self.training:
9495
layer_outputs = self._gradient_checkpointing_func(
@@ -217,6 +218,10 @@ def pipeline_llama(
217218
model_config: PretrainedConfig,
218219
loss_fn: Callable[..., torch.Tensor],
219220
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
221+
logger.info("Changing model.model.layers to nn.ModuleDict")
222+
model.model.layers = nn.ModuleDict(
223+
{str(i): layer for i, layer in enumerate(model.model.layers)}
224+
)
220225
logger.info(
221226
"Patching Llama forward method for pipeline parallelism, it will disable some features of orignal HF model"
222227
)
@@ -277,20 +282,14 @@ def _build_stage(
277282
model.model.embed_tokens = None
278283

279284
drop_layers = start_layer is not None
280-
del_indexes = []
281-
for i in range(len(model.model.layers)):
285+
for name in list(model.model.layers.keys()):
282286
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
283-
if f"layers.{i}" == start_layer:
287+
if f"layers.{name}" == start_layer:
284288
drop_layers = False
285-
if f"layers.{i}" == stop_layer:
289+
if f"layers.{name}" == stop_layer:
286290
drop_layers = True
287291
if drop_layers:
288-
del_indexes.append(i)
289-
290-
# delete layers in reverse order to avoid index shifting
291-
del_indexes.reverse()
292-
for i in del_indexes:
293-
del model.model.layers[i]
292+
del model.model.layers[name]
294293

295294
if not is_last:
296295
model.model.norm = None

torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
normalize_state_dict_key,
1717
)
1818

19+
from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf
20+
1921
from torchtitan.experiments.train_llama_hf.model.parallelize_llama import (
2022
apply_fsdp,
2123
apply_tp,
2224
)
23-
from torchtitan.experiments.train_llama_hf.model.pipeline_llama import (
24-
pipeline_llama_manual_split,
25-
)
25+
from torchtitan.experiments.train_llama_hf.model.pipeline_llama import pipeline_llama
2626

2727

2828
def main(job_config: JobConfig):
@@ -52,13 +52,14 @@ def main(job_config: JobConfig):
5252
# apply parallelisms
5353
if parallel_dims.pp_enabled:
5454
# apply PT-D Pipeline Parallel
55-
_, model_parts = pipeline_llama_manual_split(
55+
_, model_parts, _, _ = pipeline_llama(
5656
model,
5757
world_mesh["pp"],
5858
parallel_dims,
5959
job_config,
6060
device,
6161
model_config,
62+
loss_fn=cross_entropy_loss_hf,
6263
)
6364
else:
6465
model_parts = [model]

0 commit comments

Comments
 (0)