
Commit 75cef61

Move compile into AutoParallel API (#77)
Requires upstream pytorch/pytorch#159814 to land first (done). Also updates the joint_with_descriptors meta info to reflect that arg/placeholder shapes have changed: inputs and params are now sharded, not global, as they were when first traced.

Kicked off a verification run from this PR @ acaaf9c: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchtitan-64-whc-kf1llhnr

tbm runs:
FSDP_eager: torchtitan-64-whc-p3s1bn
compile_pr: torchtitan-64-whc-kf1llhnr
compile_noac_from_update_post: torchtitan-64-fmassa-r4rvfnf6
1 parent 0cd63db commit 75cef61
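
With this change, compilation is requested through the AutoParallel constructor (compile=True routes the joint graph through compile_fx_inner) instead of calling compile() on the parallelized module afterwards. Below is a minimal usage sketch assembled from the example diffs in this commit; model, input_fn, mesh, mp_policy, and the shape names are placeholders for whatever the caller defines, and the apply_placement return value is an assumption, not something shown here:

    import torch
    from torch.distributed.tensor import Replicate, Shard
    from autoparallel.api import AutoParallel

    # model, input_fn, mesh, mp_policy: assumed to be set up as in the examples
    with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
        parallel_mod = autop.apply_placement()  # assumed return value (hypothetical)

    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()
    x = torch.rand(local_bs, seq_len, dim1, device="cuda")  # local (sharded) batch; names hypothetical
    out = parallel_mod(x)  # no separate parallel_mod.compile(fullgraph=True) call anymore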

File tree

4 files changed (+62, -13 lines)

autoparallel/api.py

Lines changed: 56 additions & 4 deletions
@@ -7,13 +7,16 @@
 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
+    boxed_nop_preserve_node_meta,
 )
+from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table
 from torch._inductor.fx_passes.joint_graph import joint_graph_passes
 from torch._inductor.fx_passes.post_grad import remove_assert_ops
@@ -23,6 +26,8 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState

 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +36,40 @@
 from .utils import _get_device_from_mesh


+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    updated_gm: GraphModule,
+) -> None:
+    """
+    Assuming we have transformed updated_gm since the time it was captured,
+    (e.g. by parallelizing it),
+    this util updates the joint_with_descriptors struct to reference the new gm, and
+    updates any copies of tensor meta/shape stored in joint_with_descriptors relating to input arguments,
+    which may have changed shape since the initial trace.
+    """
+    # TODO: should we upstream a util like this?
+    placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
+    new_local_args = [n.meta["val"] for n in placeholders]
+    joint_with_descriptors.graph_module = updated_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = updated_gm
+
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
@@ -163,6 +202,7 @@ def __init__(
         input_fn,
         mesh: DeviceMesh,
         mp_policy: Optional[MixedPrecisionPolicy] = None,
+        compile: bool = False,
     ):
         self.stack = ExitStack()
         self.fake_mode = (
@@ -187,6 +227,7 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
+        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta

         # NB: rest of the construction happens in __enter__

@@ -196,6 +237,10 @@ def __enter__(self):
         assert self.active is False

         self.build_model_graph()
+        self.old_inductor_comprehensive_padding = (
+            torch._inductor.config.comprehensive_padding
+        )
+        torch._inductor.config.comprehensive_padding = False

         rescale_grad_comm_cost_for_mp = 1.0
         if self.mp_policy is not None:
@@ -224,6 +269,9 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._inductor.config.comprehensive_padding = (
+            self.old_inductor_comprehensive_padding
+        )
         self.active = None
         return self.stack.__exit__(exc_type, exc_val, exc_tb)

@@ -252,7 +300,12 @@ def build_model_graph(self):
         with set_dtype_cast(True):
             ep = torch.export.export(self.model, inputs)
         self.joint_with_descriptors = aot_export_joint_with_descriptors(
-            self.stack, ep.module(), inputs, decompositions=decomp_table
+            self.stack,
+            ep.module(),
+            inputs,
+            decompositions=decomp_table,
+            fw_compiler=self.compiler_fn,
+            bw_compiler=self.compiler_fn,
         )
         gm = self.joint_with_descriptors.graph_module

@@ -392,8 +445,7 @@ def apply_placement(self, sharding_placement=None):
         # parallel_gm, self.params_len, self.buffer_len, self.metadata
         # )
         self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm)
         # NB: so this function takes in the parameters at the beginning

         # let's remove those otherwise we can't clean the backward graph properly
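
For context on update_joint_with_descriptors above: the new local (sharded) shapes it propagates come straight from the placeholder nodes' meta["val"] entries of the rewritten graph. A standalone sketch, not part of this commit, showing where that metadata lives (the toy function f is an assumption for illustration):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, w):
        return (x @ w).relu()

    # Tracing under fake tensors records a FakeTensor in meta["val"] of each placeholder
    gm = make_fx(f, tracing_mode="fake")(torch.randn(8, 16), torch.randn(16, 4))
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            print(n.name, tuple(n.meta["val"].shape))  # e.g. x_1 (8, 16), w_1 (16, 4)
    # If a later transform rewrites the graph with different placeholder shapes (e.g.
    # sharded locals), re-reading meta["val"] like this yields the new shapes, which
    # update_joint_with_descriptors then copies into the AOT state (flat_args, tangents).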

autoparallel/apply_sharding.py

Lines changed: 3 additions & 4 deletions
@@ -265,21 +265,20 @@ def _cleanup_graph(gm):

 def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     args = shard_nodes_given_placements(gm, sharding_placement)
+    local_args = [arg.to_local() for arg in args]

     decomp_table = select_decomp_table()
     # run with DTensor to apply the collectives given the graph
     interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)

-    args = [x.to_local() for x in args]
-
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
-        parallel_gm0 = make_fx(interp.run)(*args)
+        parallel_gm0 = make_fx(interp.run)(*local_args)

     _cleanup_graph(parallel_gm0)
     interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp2.run)(*args)
+        parallel_gm = make_fx(interp2.run)(*local_args)
     _cleanup_graph(parallel_gm)

     # Copy descriptors over to new graph

examples/example_autoparallel.py

Lines changed: 2 additions & 4 deletions
@@ -93,6 +93,7 @@ def forward(self, x):


 def input_fn():
+    print(f"global input shape: {(bs, seq_len, dim1)}")
     return torch.rand(bs, seq_len, dim1, device="cuda")


@@ -104,10 +105,9 @@ def input_fn():
 # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
 # mp_policy = None

-with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
+with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
     assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
-
     autop.add_parameter_memory_constraint(low=None, high=None)

     x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
@@ -123,8 +123,6 @@ def input_fn():
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()

-parallel_mod.compile(fullgraph=True)
-
 # now let's run it
 x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
 out = parallel_mod(*x)

examples/example_llama3.py

Lines changed: 1 addition & 1 deletion
@@ -605,7 +605,7 @@ def input_fn():
 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

 # parallelize the model
-with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
+with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     autop.add_parameter_memory_constraint(low=None, high=None)

     x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
