
import copy
from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, TYPE_CHECKING, Union

import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

+if TYPE_CHECKING:
+    from torchtitan.parallelisms import ParallelDims
+
+
+DeviceType = Union[int, str, torch.device]
+
# for selective AC
no_recompute_list = {
    torch.ops.aten.mm.default,
@@ -125,23 +134,30 @@ def get_tp_parallel_strategy(


def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    valid_split_modes = ("manual", "tracer")
+    if split_mode not in valid_split_modes:
+        raise ValueError(
+            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
+        )
+    if split_mode == "manual":
        return pipeline_llama_manual(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
        return pipeline_llama_tracer(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )


-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
    """Get meta tensors with the right input shapes used for tracing"""
    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
    tokens = torch.randint(
@@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
    job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
):
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
@@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):


def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
):
    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+        # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+        # fused_rmsnorm
        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
        )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
        raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
        )

    pp_mesh = world_mesh["pp"]
@@ -310,10 +331,13 @@ def pipeline_llama_tracer(
    return (stages, models)


-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""

    tp_mesh = world_mesh["tp"]
    # Parallel styles used for transformer block linear weights and their
@@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""

    ac_config = job_config.activation_checkpoint

@@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig):
    return model


-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""

    if job_config.model.norm_type == "fused_rmsnorm":
        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
+            "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
        )

    for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
        # TODO: dynamic shape have some issues so we turn it off for now.
        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
        # compile time.
@@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig):
    return model


-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""

    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.