
import copy
from collections import defaultdict
- from typing import Dict, Tuple
+ from typing import Tuple, TYPE_CHECKING, Union

import torch
+ import torch.nn as nn
+ from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
+ from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

+ if TYPE_CHECKING:
+     from torchtitan.parallelisms import ParallelDims
+
+
+ DeviceType = Union[int, str, torch.device]
+
# for selective AC
no_recompute_list = {
    torch.ops.aten.mm.default,
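
Two small additions above are worth unpacking: the TYPE_CHECKING guard makes the ParallelDims import visible to type checkers without executing it at runtime (sidestepping a circular import), and DeviceType is simply an alias for the three forms a torch device argument can take. A minimal sketch of how such an alias is used (the describe_device helper is made up for illustration, not a torchtitan API):

from typing import Union

import torch

DeviceType = Union[int, str, torch.device]


def describe_device(device: DeviceType) -> str:
    # torch.device() accepts the same int / str / torch.device forms
    # the alias enumerates, so annotated call sites can normalize here.
    return str(torch.device(device))


print(describe_device("cpu"))                # cpu
print(describe_device(torch.device("cpu")))  # cpu
print(describe_device("cuda:0"))             # cuda:0 (constructing the device needs no GPU)
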
@@ -112,23 +121,27 @@ def get_tp_parallel_strategy(


def pipeline_llama(
-     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: ModelArgs,
):
-     if job_config.experimental.pipeline_parallel_split_mode == "manual":
+     split_mode = job_config.experimental.pipeline_parallel_split_mode
+     if split_mode == "manual":
        return pipeline_llama_manual(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
-     elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+     elif split_mode == "tracer":
        return pipeline_llama_tracer(
            model, world_mesh, parallel_dims, job_config, device, model_config
        )
    else:
-         raise NotImplementedError(
-             f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-         )
+         raise NotImplementedError(f"{split_mode} is not a valid split mode")


- def _llama_trace_input(job_config, model_config, device="meta"):
+ def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
    """Get meta tensors with the right input shapes used for tracing"""
    tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
    tokens = torch.randint(
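
The hunk cuts off inside _llama_trace_input, but the idea is that the tracer only needs inputs with the right shapes, not real data, so they are allocated on the meta device. A small standalone sketch of the same trick (the batch size, sequence length, and vocab size values are made up):

import torch

batch_size, seq_len, vocab_size = 8, 2048, 32000

# Meta tensors carry only shape and dtype, so even very large example
# inputs cost no memory; they exist purely for tracing.
tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device="meta")
print(tokens.shape, tokens.device)  # torch.Size([8, 2048]) meta
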
@@ -140,18 +153,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):

def _mixed_precision_dtype(
    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
-     """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+     """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
    mp_arg = job_config.training.mixed_precision_param
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
-     whole_model,
-     world_mesh,
-     parallel_dims,
+     whole_model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
    job_config: JobConfig,
-     device,
-     model_config: Dict,
+     device: DeviceType,
+     model_config: ModelArgs,
):
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
@@ -249,19 +262,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal


def pipeline_llama_tracer(
-     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+     device: DeviceType,
+     model_config: ModelArgs,
):
-     if job_config.model.norm_type == "fused_rmsnorm":
-         # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-         # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+     if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
        raise NotImplementedError(
-             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-         )
-
-     if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-         raise NotImplementedError(
-             "pipeline tracer doesn't work with fsdp mixed precision currently. "
-             "To work around, edit fsdp mixed precision config to use fp32."
+             "Pipeline tracer does not work with FSDP mixed precision yet. "
+             "To work around, set mixed_precision_param to float32."
        )

    pp_mesh = world_mesh["pp"]
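
The new guard generalizes the old bfloat16-only check: the tracer supports only full-precision parameters, so anything other than torch.float32 coming back from _mixed_precision_dtype is rejected. A standalone sketch of that helper-plus-guard shape (the dp_enabled flag and dtype map here are simplified stand-ins for torchtitan's ParallelDims and TORCH_DTYPE_MAP):

import torch

DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16}


def mixed_precision_dtype(
    mixed_precision_param: str, dp_enabled: bool, default: torch.dtype = torch.float32
) -> torch.dtype:
    # Mixed precision only applies when FSDP (data parallel) is on;
    # otherwise parameters stay in the default dtype.
    return DTYPE_MAP[mixed_precision_param] if dp_enabled else default


def check_tracer_supported(mixed_precision_param: str, dp_enabled: bool) -> None:
    if mixed_precision_dtype(mixed_precision_param, dp_enabled) != torch.float32:
        raise NotImplementedError(
            "Pipeline tracer does not work with FSDP mixed precision yet. "
            "To work around, set mixed_precision_param to float32."
        )


check_tracer_supported("bfloat16", dp_enabled=False)  # fine: FSDP off, params stay fp32
check_tracer_supported("float32", dp_enabled=True)    # fine: fp32 params
# check_tracer_supported("bfloat16", dp_enabled=True) would raise NotImplementedError
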
@@ -297,10 +308,13 @@ def pipeline_llama_tracer(
    return (stages, models)


- def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-     """
-     Apply tensor parallelism.
-     """
+ def apply_tp(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
+     """Apply tensor parallelism."""

    tp_mesh = world_mesh["tp"]
    # Parallel styles for transformer block linear weights may be different for
@@ -379,10 +393,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


- def apply_ac(model, job_config: JobConfig):
-     """
-     Apply activation checkpointing to the model.
-     """
+ def apply_ac(model: nn.Module, job_config: JobConfig):
+     """Apply activation checkpointing to the model."""

    ac_config = job_config.activation_checkpoint

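
For context, apply_ac (its body lies outside this diff) wraps each transformer block so activations are recomputed during backward instead of being stored. A minimal sketch of full (non-selective) checkpointing with PyTorch's checkpoint_wrapper, using a toy block in place of torchtitan's TransformerBlock:

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


layers = nn.ModuleDict({str(i): ToyBlock() for i in range(4)})

# Wrap every block: activations inside each wrapper are freed after forward
# and recomputed in backward, trading compute for memory.
for layer_id, block in list(layers.named_children()):
    layers[layer_id] = checkpoint_wrapper(block)

x = torch.randn(2, 16, 64, requires_grad=True)
out = x
for block in layers.values():
    out = block(out)
out.sum().backward()
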
@@ -394,18 +406,10 @@ def apply_ac(model, job_config: JobConfig):
    return model


- def apply_compile(model, job_config: JobConfig):
-     """
-     Apply torch.compile to the model.
-     """
-
-     if job_config.model.norm_type == "fused_rmsnorm":
-         raise NotImplementedError(
-             "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-         )
+ def apply_compile(model: nn.Module, job_config: JobConfig):
+     """Apply torch.compile to each transformer block."""

    for layer_id, transformer_block in model.layers.named_children():
-         # turn on per-transformer block compile after AC wrapping and before FSDP
        # TODO: dynamic shape have some issues so we turn it off for now.
        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
        # compile time.
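
The loop body is cut off by the hunk, but the pattern is to compile each block individually (after AC wrapping, before FSDP, per the removed comment) rather than the whole model. A self-contained sketch of per-block compilation; the container of linear layers is a stand-in for model.layers, and dynamic=False mirrors the "turn off dynamic shapes" TODO:

import torch
import torch.nn as nn

# Stand-in for model.layers: a container of per-layer submodules.
layers = nn.ModuleDict({str(i): nn.Linear(32, 32) for i in range(4)})

for layer_id, block in list(layers.named_children()):
    # Compile each block on its own and swap the compiled module back in.
    compiled = torch.compile(block, dynamic=False)
    layers.register_module(layer_id, compiled)

x = torch.randn(2, 32)
for block in layers.values():
    x = block(x)
print(x.shape)
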
@@ -417,10 +421,13 @@ def apply_compile(model, job_config: JobConfig):
    return model


- def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-     """
-     Apply data parallelism (FSDP2) to the model.
-     """
+ def apply_dp(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
+     """Apply data parallelism (FSDP2) to the model."""

    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
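
apply_dp builds on the FSDP2 fully_shard API imported at the top of the file: each transformer block is typically sharded on the "dp" mesh with a mixed-precision policy, then the root module is sharded once more. A sketch of that pattern under stated assumptions (a process group already initialized, e.g. via torchrun, and illustrative dtypes):

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is initialized and GPUs are available.
dp_mesh = init_device_mesh(
    "cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",)
)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

model = nn.Sequential(*[nn.Linear(128, 128) for _ in range(4)])

# Shard each block first so its parameters can be gathered/freed per layer,
# then shard the root to cover whatever parameters remain.
for block in model:
    fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
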
@@ -453,7 +460,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
    return model


- def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+ def parallelize_llama(
+     model: nn.Module,
+     world_mesh: DeviceMesh,
+     parallel_dims: "ParallelDims",
+     job_config: JobConfig,
+ ):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.
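
The docstring names the four transforms; as the removed comment in apply_compile notes, the intended order is TP, then AC, then per-block compile, then FSDP. A schematic sketch of that orchestration with identity stubs so it runs (the flag names are illustrative, not the exact torchtitan config fields):

import torch.nn as nn

# Identity stubs standing in for apply_tp / apply_ac / apply_compile / apply_dp.
def apply_tp_stub(m): return m
def apply_ac_stub(m): return m
def apply_compile_stub(m): return m
def apply_dp_stub(m): return m


def parallelize(
    model: nn.Module,
    tp_enabled: bool,
    ac_mode: str,
    compile_enabled: bool,
    dp_enabled: bool,
) -> nn.Module:
    # Order matters: TP re-parameterizes weights, AC wraps blocks, compile wraps
    # the (possibly AC-wrapped) blocks, and FSDP shards whatever is left.
    if tp_enabled:
        model = apply_tp_stub(model)
    if ac_mode != "none":
        model = apply_ac_stub(model)
    if compile_enabled:
        model = apply_compile_stub(model)
    if dp_enabled:
        model = apply_dp_stub(model)
    return model


print(parallelize(nn.Linear(4, 4), tp_enabled=True, ac_mode="full", compile_enabled=False, dp_enabled=True))
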