import copy
from collections import defaultdict
- from typing import Dict, Tuple
+ from typing import Tuple, TYPE_CHECKING, Union

import torch
+ import torch.nn as nn
+ from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard

@@ -29,8 +31,15 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
+ from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

+ if TYPE_CHECKING:
+     from torchtitan.parallelisms import ParallelDims
+
+
+ DeviceType = Union[int, str, torch.device]
+
# for selective AC
no_recompute_list = {
    torch.ops.aten.mm.default,
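The new `TYPE_CHECKING` block and `DeviceType` alias follow a standard typing pattern: the `ParallelDims` import exists only for static analysis (avoiding a runtime import cycle with `torchtitan.parallelisms`), and the alias documents that a device may arrive as an index, a string, or a `torch.device`. A minimal self-contained sketch of the same pattern; the package path and helper functions below are illustrative, not torchtitan code:

```python
from typing import TYPE_CHECKING, Union

import torch
import torch.nn as nn

if TYPE_CHECKING:
    # Evaluated only by static type checkers, so it cannot create a runtime import cycle.
    from mypackage.parallelisms import ParallelDims  # hypothetical package path

# A device may be specified as an index, a string such as "cuda:0", or a torch.device.
DeviceType = Union[int, str, torch.device]


def to_device(batch: torch.Tensor, device: DeviceType) -> torch.Tensor:
    # torch.device() accepts strings and ints (an int is treated as a CUDA index).
    return batch.to(device if isinstance(device, torch.device) else torch.device(device))


def shard(model: nn.Module, dims: "ParallelDims") -> None:
    # The quoted annotation means ParallelDims only has to exist for the type checker.
    ...
```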
@@ -107,23 +116,27 @@ def get_tp_parallel_strategy(


def pipeline_llama(
-     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: ModelArgs,
):
-     if job_config.experimental.pipeline_parallel_split_mode == "manual":
+     split_mode = job_config.experimental.pipeline_parallel_split_mode
+     if split_mode == "manual":
        return pipeline_llama_manual(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
-     elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+     elif split_mode == "tracer":
        return pipeline_llama_tracer(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
    else:
-         raise NotImplementedError(
-             f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-         )
+         raise NotImplementedError(f"{split_mode} is not a valid split mode")


- def _llama_trace_input(job_config, model_config, device="meta"):
+ def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
    """Get meta tensors with the right input shapes used for tracing"""
    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
    tokens = torch.randint(
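Hoisting the repeated `job_config.experimental.pipeline_parallel_split_mode` lookup into a local `split_mode` keeps the two branches and the error message in sync. A self-contained sketch of the same dispatch shape, using a toy config class and stub builders (all names here are illustrative):

```python
from dataclasses import dataclass


@dataclass
class ExperimentalConfig:
    pipeline_parallel_split_mode: str = "manual"


def build_pipeline_manual(cfg: ExperimentalConfig) -> str:
    return "manual stages"


def build_pipeline_tracer(cfg: ExperimentalConfig) -> str:
    return "traced stages"


def build_pipeline(cfg: ExperimentalConfig) -> str:
    # Read the config value once so the branches and the error message stay in sync.
    split_mode = cfg.pipeline_parallel_split_mode
    if split_mode == "manual":
        return build_pipeline_manual(cfg)
    elif split_mode == "tracer":
        return build_pipeline_tracer(cfg)
    else:
        raise NotImplementedError(f"{split_mode} is not a valid split mode")


print(build_pipeline(ExperimentalConfig()))          # manual stages
print(build_pipeline(ExperimentalConfig("tracer")))  # traced stages
```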
@@ -135,18 +148,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
-     """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+     """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
-     whole_model,
-     world_mesh,
-     parallel_dims,
+     whole_model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
    job_config: JobConfig,
-     device,
-     model_config: Dict,
+     device: DeviceType,
+     model_config: ModelArgs,
):
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
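`_mixed_precision_dtype` only honors `training.mixed_precision_param` when data parallelism (FSDP) is enabled; otherwise the default dtype is returned. A minimal sketch of that rule with an inlined string-to-dtype map (the map below is an assumption standing in for torchtitan's `TORCH_DTYPE_MAP`):

```python
import torch

# Assumed mapping from config strings to dtypes, standing in for TORCH_DTYPE_MAP.
DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}


def mixed_precision_dtype(
    mixed_precision_param: str, dp_enabled: bool, default: torch.dtype = torch.float32
) -> torch.dtype:
    """Return the parameter dtype used for FSDP mixed precision, or the default."""
    return DTYPE_MAP[mixed_precision_param] if dp_enabled else default


assert mixed_precision_dtype("bfloat16", dp_enabled=True) == torch.bfloat16
assert mixed_precision_dtype("bfloat16", dp_enabled=False) == torch.float32
```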
@@ -244,19 +257,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False


def pipeline_llama_tracer(
-     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: ModelArgs,
):
-     if job_config.model.norm_type == "fused_rmsnorm":
-         # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-         # coming from `if dy.stride(-1) != 1:` in fused_rmsnorm
+     if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
        raise NotImplementedError(
-             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-         )
-
-     if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-         raise NotImplementedError(
-             "pipeline tracer doesn't work with fsdp mixed precision currently. "
-             "To work around, edit fsdp mixed precision config to use fp32."
+             "Pipeline tracer does not work with FSDP mixed precision yet. "
+             "To work around, set mixed_precision_param to float32."
        )

    pp_mesh = world_mesh["pp"]
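Note that the replacement guard is deliberately broader than the old `== torch.bfloat16` check: any non-float32 parameter dtype, including float16, now raises. A short illustration:

```python
import torch

for dtype in (torch.bfloat16, torch.float16, torch.float32):
    # The old check only rejected bfloat16; the new one rejects every non-fp32 dtype.
    print(dtype, "rejected" if dtype != torch.float32 else "accepted")
```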
@@ -292,10 +303,13 @@ def pipeline_llama_tracer(
    return (stages, models)


- def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-     """
-     Apply tensor parallelism.
-     """
+ def apply_tp(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
+     """Apply tensor parallelism."""

    tp_mesh = world_mesh["tp"]
    # Parallel styles for transformer block linear weights may be different for
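`apply_tp` (only its signature changes here) builds a tensor-parallel plan over `world_mesh["tp"]`. For orientation, this is roughly what a plan based on `parallelize_module` looks like; the toy FFN, the mesh construction, and the column/row split are illustrative rather than torchtitan's llama plan, and the snippet needs a multi-GPU distributed launch (e.g. torchrun) to actually run:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyFFN(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(self.w1(x).relu())


def apply_toy_tp(model: nn.Module, tp_degree: int) -> nn.Module:
    # Shard w1 column-wise and w2 row-wise so the pair needs a single all-reduce.
    tp_mesh = init_device_mesh("cuda", (tp_degree,), mesh_dim_names=("tp",))
    return parallelize_module(
        model, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
    )
```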
@@ -374,10 +388,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


- def apply_ac(model, job_config: JobConfig):
-     """
-     Apply activation checkpointing to the model.
-     """
+ def apply_ac(model: nn.Module, job_config: JobConfig):
+     """Apply activation checkpointing to the model."""

    ac_config = job_config.activation_checkpoint
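`apply_ac` (its body is unchanged in this hunk) wraps each transformer block according to `job_config.activation_checkpoint`. A sketch of the simplest "full" mode, assuming the model exposes its blocks under `model.layers`; the helper name is an assumption, and selective AC (driven by the `no_recompute_list` above) is omitted:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


def apply_full_ac(model: nn.Module) -> nn.Module:
    # Recompute each block's activations during backward instead of storing them.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, checkpoint_wrapper(block))
    return model
```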
@@ -389,18 +401,10 @@ def apply_ac(model, job_config: JobConfig):
    return model


- def apply_compile(model, job_config: JobConfig):
-     """
-     Apply torch.compile to the model.
-     """
-
-     if job_config.model.norm_type == "fused_rmsnorm":
-         raise NotImplementedError(
-             "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-         )
+ def apply_compile(model: nn.Module, job_config: JobConfig):
+     """Apply torch.compile to each transformer block."""

    for layer_id, transformer_block in model.layers.named_children():
-         # turn on per-transformer block compile after AC wrapping and before FSDP
        # TODO: dynamic shape have some issues so we turn it off for now.
        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
        # compile time.
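Per-block `torch.compile` (rather than compiling the whole model) keeps each graph small and composes with the AC wrapping applied before it and the FSDP wrapping applied after it. A sketch of the loop shown in context above; `dynamic=False` mirrors the dynamic-shape TODO, and the helper name is an assumption:

```python
import torch
import torch.nn as nn


def compile_blocks(model: nn.Module) -> nn.Module:
    # Compile each transformer block separately and re-register it on the parent module.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, dynamic=False))
    return model
```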
@@ -412,10 +416,13 @@ def apply_compile(model, job_config: JobConfig):
    return model


- def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-     """
-     Apply data parallelism (FSDP2) to the model.
-     """
+ def apply_dp(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
+     """Apply data parallelism (FSDP2) to the model."""

    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
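`apply_dp` shards with the composable FSDP2 `fully_shard` imported at the top of this file, typically block by block and then at the root so parameters are grouped per block. The sketch below is a simplified assumption (mesh setup, dtype choices, and reshard heuristics differ in torchtitan) and needs a distributed launch to run:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh


def apply_fsdp2(model: nn.Module, dp_degree: int) -> nn.Module:
    dp_mesh = init_device_mesh("cuda", (dp_degree,), mesh_dim_names=("dp",))
    # Keep params in bf16 for compute, reduce gradients in fp32 for stability.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for block in model.layers.children():
        fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
    # Shard the remaining root parameters (embeddings, norms, output projection).
    fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
    return model
```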
@@ -448,7 +455,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


- def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+ def parallelize_llama(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.
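The body of `parallelize_llama` falls outside this hunk. Assuming it simply composes the helpers above in the order the docstring lists, gated by config and parallelism flags (the exact condition names below are assumptions), the flow looks roughly like:

```python
def parallelize_llama_sketch(model, world_mesh, parallel_dims, job_config):
    # Order matters: TP reshapes weights first, AC wraps blocks, per-block compile
    # runs next, and FSDP wraps last so it shards the already-wrapped blocks.
    if parallel_dims.tp_enabled:
        model = apply_tp(model, world_mesh, parallel_dims, job_config)
    if job_config.activation_checkpoint.mode != "none":
        model = apply_ac(model, job_config)
    if job_config.training.compile:
        model = apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        model = apply_dp(model, world_mesh, parallel_dims, job_config)
    return model
```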