
Commit 8e70c82

Author: Andrew Gu

Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d
Pull Request resolved: #449
1 parent 2860493 commit 8e70c82


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 69 additions & 42 deletions
@@ -9,9 +9,11 @@
 
 import copy
 from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union
 
 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
 
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
 
+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
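
Aside: the new TYPE_CHECKING block imports ParallelDims only for static type checkers, so there is no circular import at runtime, and the annotations below refer to it as the string "ParallelDims" (a forward reference). A minimal sketch of that pattern, with hypothetical module and class names that are not torchtitan's:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by type checkers such as mypy, never at runtime
        from mypkg.dims import ParallelDims

    def configure(parallel_dims: "ParallelDims") -> None:
        # the quoted annotation is resolved lazily by the type checker
        ...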
@@ -125,23 +134,30 @@ def get_tp_parallel_strategy(
 
 
 def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    valid_split_modes = ("manual", "tracer")
+    if split_mode not in valid_split_modes:
+        raise ValueError(
+            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
+        )
+    if split_mode == "manual":
         return pipeline_llama_manual(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
         return pipeline_llama_tracer(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
 
 
-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
     tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
     tokens = torch.randint(
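
Note on the pipeline_llama hunk above: the split mode is now validated up front, so a bad config fails fast with a ValueError listing the accepted values instead of falling through to a trailing else branch. A quick illustrative check (the invalid value here is hypothetical):

    split_mode = "autosplit"  # hypothetical, not a supported mode
    valid_split_modes = ("manual", "tracer")
    if split_mode not in valid_split_modes:
        raise ValueError(
            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
        )
    # ValueError: Invalid split mode: autosplit. Valid split modes: ('manual', 'tracer')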
@@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
 def _mixed_precision_dtype(
     job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
 
 
 def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
     job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
@@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
 
 def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+        # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+        # fused_rmsnorm
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
         )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
         raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
         )
 
     pp_mesh = world_mesh["pp"]
@@ -310,10 +331,13 @@ def pipeline_llama_tracer(
     return (stages, models)
 
 
-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""
 
     tp_mesh = world_mesh["tp"]
     # Parallel styles used for transformer block linear weights and their
@@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""
 
     ac_config = job_config.activation_checkpoint
 
@@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig):
     return model
 
 
-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""
 
     if job_config.model.norm_type == "fused_rmsnorm":
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
         )
 
     for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
         # compile time.
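
For context, apply_compile wraps each transformer block with torch.compile individually rather than compiling the whole model at once. A minimal sketch of that per-block pattern, assuming the model keeps its blocks in a `layers` container with named children as the Llama model here does (the helper itself is illustrative, not the repo's exact code):

    import torch
    import torch.nn as nn

    def compile_blocks(model: nn.Module) -> nn.Module:
        # compile block by block so each block gets its own graph
        for layer_id, block in model.layers.named_children():
            compiled_block = torch.compile(block, dynamic=False)
            # swap the compiled module back in under the same name
            model.layers.register_module(layer_id, compiled_block)
        return model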
@@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""
 
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.
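
Taken together, parallelize_llama is the typed entry point that composes the helpers above in the order its docstring states: tensor parallelism, activation checkpointing, per-block compile, then FSDP2. A hedged sketch of that composition, where the exact enablement checks are assumptions for illustration rather than lines copied from this file:

    def parallelize(model, world_mesh, parallel_dims, job_config):
        if parallel_dims.tp_enabled:
            model = apply_tp(model, world_mesh, parallel_dims, job_config)
        model = apply_ac(model, job_config)
        if job_config.training.compile:
            model = apply_compile(model, job_config)
        if parallel_dims.dp_enabled:
            model = apply_dp(model, world_mesh, parallel_dims, job_config)
        return model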
