
Commit e824473

Author: Andrew Gu
Moved more checks to config manager plus more stylistic changes
ghstack-source-id: 4dc6153 Pull Request resolved: #449
1 parent 348dd59 commit e824473

2 files changed: +78 −49 lines


torchtitan/config_manager.py

Lines changed: 18 additions & 1 deletion
@@ -533,11 +533,28 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
         return args_dict
 
     def _validate_config(self) -> None:
-        # TODO: Add more mandatory validations
         assert self.model.name
         assert self.model.flavor
         assert self.model.tokenizer_path
 
+        pp_split_mode = self.experimental.pipeline_parallel_split_mode
+        if pp_split_mode not in ("manual", "tracer"):
+            raise ValueError(
+                f"Invalid split mode: {self.experimental.pipeline_parallel_split_mode}"
+            )
+        if pp_split_mode == "tracer" and self.model.norm_type == "fused_rmsnorm":
+            # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+            # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+            # fused_rmsnorm
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
+            )
+
+        if self.training.compile and self.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
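
For context, here is a minimal standalone sketch of the validation flow this hunk moves into the config manager: invalid combinations now fail at config-parse time rather than deep inside parallelize_llama. ToyConfig and its fields are hypothetical stand-ins, not torchtitan's actual JobConfig.

```python
# Minimal standalone sketch of the checks moved into _validate_config above.
# ToyConfig is a hypothetical stand-in, not torchtitan's JobConfig.
from dataclasses import dataclass


@dataclass
class ToyConfig:
    pipeline_parallel_split_mode: str = "manual"
    norm_type: str = "rmsnorm"
    compile: bool = False

    def validate(self) -> None:
        if self.pipeline_parallel_split_mode not in ("manual", "tracer"):
            raise ValueError(f"Invalid split mode: {self.pipeline_parallel_split_mode}")
        if self.pipeline_parallel_split_mode == "tracer" and self.norm_type == "fused_rmsnorm":
            raise NotImplementedError("fused_rmsnorm is not compatible with Pipeline Tracer yet.")
        if self.compile and self.norm_type == "fused_rmsnorm":
            raise NotImplementedError("fused_rmsnorm is not compatible with torch.compile yet.")


ToyConfig().validate()  # passes
try:
    ToyConfig(pipeline_parallel_split_mode="tracer", norm_type="fused_rmsnorm").validate()
except NotImplementedError as e:
    print(e)  # rejected at config-parse time, before any model is built
```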

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 60 additions & 48 deletions
@@ -9,9 +9,11 @@
 
 import copy
 from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union
 
 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
 
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
 
+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
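
The TYPE_CHECKING import plus the quoted "ParallelDims" annotations used throughout this file follow the standard pattern for avoiding runtime import cycles. A small self-contained sketch of that pattern; the some_package/SomeDims names are hypothetical:

```python
# Sketch of the TYPE_CHECKING pattern used above: the guarded import only runs
# under static type checkers, so the annotation must be a string (forward
# reference) at runtime. `some_package.SomeDims` is a hypothetical name.
from typing import TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    from some_package import SomeDims  # hypothetical; never imported at runtime

DeviceType = Union[int, str, torch.device]


def describe(dims: "SomeDims", device: DeviceType) -> str:
    # Annotations are not evaluated at runtime, so any object can be passed here.
    return f"{type(dims).__name__} on {device}"


print(describe(object(), torch.device("cpu")))
```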
@@ -112,23 +121,27 @@ def get_tp_parallel_strategy(
 
 
 def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    if split_mode == "manual":
         return pipeline_llama_manual(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
         return pipeline_llama_tracer(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
     else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
+        raise NotImplementedError(f"{split_mode} is not a valid split mode")
 
 
-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
     tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
     tokens = torch.randint(
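
As a side note on _llama_trace_input: meta-device tensors carry shapes and dtypes without allocating storage, which is what makes them usable as tracing inputs. A small sketch with assumed example values (batch size 8, sequence length 2048, vocab size 32000 — not torchtitan defaults):

```python
# Building an input of the right shape on the meta device, in the spirit of
# _llama_trace_input above. Shape values here are assumed examples.
import torch

batch_size, seq_len, vocab_size = 8, 2048, 32000
tokens = torch.randint(vocab_size, (batch_size, seq_len), dtype=torch.int64, device="meta")
print(tokens.shape, tokens.device)  # torch.Size([8, 2048]) meta
```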
@@ -140,18 +153,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
 def _mixed_precision_dtype(
     job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
 
 
 def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
     job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
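
_mixed_precision_dtype is a simple lookup gated on whether data parallelism is enabled. A hedged sketch of the same string-to-dtype pattern; the DTYPE_MAP contents below are an assumption and may differ from torchtitan's TORCH_DTYPE_MAP:

```python
# Sketch of the string -> torch.dtype lookup behind _mixed_precision_dtype.
# DTYPE_MAP below is an assumed mapping, not torchtitan's TORCH_DTYPE_MAP.
import torch

DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}


def mixed_precision_dtype(mp_param: str, dp_enabled: bool, default: torch.dtype = torch.float32) -> torch.dtype:
    # Only honor the mixed precision setting when data parallelism (FSDP) is enabled.
    return DTYPE_MAP[mp_param] if dp_enabled else default


assert mixed_precision_dtype("bfloat16", dp_enabled=True) == torch.bfloat16
assert mixed_precision_dtype("bfloat16", dp_enabled=False) == torch.float32
```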
@@ -249,19 +262,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False
 
 
 def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-        )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-        raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
         )
 
     pp_mesh = world_mesh["pp"]
@@ -297,10 +308,13 @@ def pipeline_llama_tracer(
     return (stages, models)
 
 
-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""
 
     tp_mesh = world_mesh["tp"]
     # Parallel styles for transformer block linear weights may be different for
@@ -379,10 +393,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""
 
     ac_config = job_config.activation_checkpoint
 
@@ -394,18 +406,10 @@ def apply_ac(model, job_config: JobConfig):
     return model
 
 
-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-        )
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""
 
     for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
         # compile time.
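
The new apply_compile docstring makes the per-block strategy explicit: each transformer block is compiled individually and re-registered under its original name. Below is a self-contained sketch of that loop shape on a toy model; ToyModel and Block are hypothetical stand-ins for the Llama transformer blocks, not torchtitan code.

```python
# Per-block torch.compile sketch: compile each child module and re-register it.
# ToyModel/Block are hypothetical stand-ins for the real transformer blocks.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block(dim) for i in range(n_layers)})

    def forward(self, x):
        for block in self.layers.values():
            x = block(x)
        return x


model = ToyModel()
# Compile each block individually; dynamic=False because the diff notes issues
# with dynamic shapes.
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, torch.compile(block, dynamic=False))

out = model(torch.randn(4, 16))  # each block is compiled lazily on first use
```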
@@ -417,10 +421,13 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""
 
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -453,7 +460,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.
