
Commit ea5dcdb

Author: Andrew Gu

Moved more checks to config manager plus more stylistic changes

ghstack-source-id: e5dbadc
Pull Request resolved: #449

1 parent: e42e56f

File tree: 2 files changed (+78, -50 lines)


torchtitan/config_manager.py

Lines changed: 18 additions & 2 deletions
@@ -533,12 +533,23 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
         return args_dict

     def _validate_config(self) -> None:
-        # TODO: Add more mandatory validations
         assert self.model.name
         assert self.model.flavor
         assert self.model.tokenizer_path

-        ac_config = self.activation_checkpoint
+        pp_split_mode = self.experimental.pipeline_parallel_split_mode
+        if pp_split_mode not in ("manual", "tracer"):
+            raise ValueError(
+                f"Invalid split mode: {self.experimental.pipeline_parallel_split_mode}"
+            )
+        if pp_split_mode == "tracer" and self.model.norm_type == "fused_rmsnorm":
+            # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+            # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+            # fused_rmsnorm
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
+            )
+
         ac_config = self.activation_checkpoint
         if ac_config.mode not in ("full", "selective", "none"):
             raise ValueError(f"Invalid AC mode: {ac_config.mode}")
@@ -549,6 +560,11 @@ def _validate_config(self) -> None:
                 f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
            )

+        if self.training.compile and self.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
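Net effect of the moved checks: invalid combinations of pipeline-parallel split mode, norm type, and compile now fail during config validation rather than later inside parallelize_llama. Below is a standalone sketch (not torchtitan code) of the same fail-fast pattern, exercised with a hypothetical SimpleNamespace config rather than the real JobConfig:

from types import SimpleNamespace


def validate(cfg) -> None:
    # Same checks as the new _validate_config additions, in standalone form.
    pp_split_mode = cfg.experimental.pipeline_parallel_split_mode
    if pp_split_mode not in ("manual", "tracer"):
        raise ValueError(f"Invalid split mode: {pp_split_mode}")
    if pp_split_mode == "tracer" and cfg.model.norm_type == "fused_rmsnorm":
        raise NotImplementedError(
            "fused_rmsnorm is not compatible with Pipeline Tracer yet."
        )
    if cfg.training.compile and cfg.model.norm_type == "fused_rmsnorm":
        raise NotImplementedError(
            "fused_rmsnorm is not compatible with torch.compile yet."
        )


cfg = SimpleNamespace(
    experimental=SimpleNamespace(pipeline_parallel_split_mode="tracer"),
    model=SimpleNamespace(norm_type="fused_rmsnorm"),
    training=SimpleNamespace(compile=False),
)
validate(cfg)  # raises NotImplementedError before any model is built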

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 60 additions & 48 deletions
@@ -9,9 +9,11 @@

 import copy
 from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union

 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh

 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@

 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
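The TYPE_CHECKING guard plus the quoted "ParallelDims" annotations added throughout this file keep the ParallelDims import out of the runtime import graph (the usual motivation is avoiding a circular import; the diff itself does not state the reason). A minimal sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at runtime.
    from torchtitan.parallelisms import ParallelDims


def describe(parallel_dims: "ParallelDims") -> str:
    # The quoted annotation stays a string at runtime, so no import is needed.
    return type(parallel_dims).__name__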
@@ -107,23 +116,27 @@ def get_tp_parallel_strategy(


 def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    if split_mode == "manual":
         return pipeline_llama_manual(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
         return pipeline_llama_tracer(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
     else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
+        raise NotImplementedError(f"{split_mode} is not a valid split mode")


-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
     tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
     tokens = torch.randint(
@@ -135,18 +148,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
 def _mixed_precision_dtype(
     job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


 def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
     job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
@@ -244,19 +257,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal


 def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-        )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-        raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
         )

     pp_mesh = world_mesh["pp"]
@@ -292,10 +303,13 @@ def pipeline_llama_tracer(
     return (stages, models)


-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""

     tp_mesh = world_mesh["tp"]
     # Parallel styles for transformer block linear weights may be different for
@@ -374,10 +388,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model


-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""

     ac_config = job_config.activation_checkpoint

@@ -389,18 +401,10 @@ def apply_ac(model, job_config: JobConfig):
     return model


-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-        )
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""

     for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
         # compile time.
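The loop body itself is outside this hunk; the sketch below illustrates the per-block compile pattern the new docstring describes, with a toy ModuleDict standing in for the real transformer blocks (an assumption, not this file's code):

import torch
import torch.nn as nn

# Toy stand-in for model.layers; the real model holds transformer blocks.
layers = nn.ModuleDict({"0": nn.Linear(8, 8), "1": nn.Linear(8, 8)})

for layer_id, block in layers.named_children():
    # Compile each block separately (dynamic=False, per the TODO above) and
    # swap the compiled module back in under the same name.
    compiled_block = torch.compile(block, dynamic=False)
    layers.register_module(layer_id, compiled_block)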
@@ -412,10 +416,13 @@ def apply_compile(model, job_config: JobConfig):
     return model


-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""

     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -448,7 +455,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model


-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.
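The body of parallelize_llama is not shown in this commit; based on its docstring and the helpers above, the composition presumably looks roughly like the sketch below. The tp_enabled flag on ParallelDims is an assumption here (only dp_enabled appears in these hunks).

# Sketch only; not the actual function body from this file.
def parallelize_llama_sketch(model, world_mesh, parallel_dims, job_config):
    if parallel_dims.tp_enabled:  # assumed flag, mirroring dp_enabled
        model = apply_tp(model, world_mesh, parallel_dims, job_config)
    if job_config.activation_checkpoint.mode != "none":
        model = apply_ac(model, job_config)
    if job_config.training.compile:
        model = apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        model = apply_dp(model, world_mesh, parallel_dims, job_config)
    return model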
