Commit debe2a3

Update node metas after parallelization
TODO: test this somehow
1 parent badffa7 commit debe2a3

2 files changed, +52 -6 lines changed
autoparallel/api.py

Lines changed: 37 additions & 3 deletions
@@ -7,10 +7,11 @@
 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
 )
@@ -23,6 +24,8 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState
 
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +34,35 @@
 from .utils import _get_device_from_mesh
 
 
+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    parallel_gm: GraphModule,
+    new_local_args: list[torch.Tensor],
+) -> None:
+    # TODO: should we upstream a util like this?
+    # Should we pass 'new_local_args' in here, or should we just rely on the
+    # parallel_gm having its placeholder metas updated and extract the
+    # new_local_args from there?
+    joint_with_descriptors.graph_module = parallel_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = parallel_gm
+
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
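
The TODO in the new helper asks whether new_local_args must be passed in explicitly or could instead be recovered from the placeholder metas that apply_sharding_to_model now updates. A minimal sketch of that alternative, assuming the placeholders' "val" metas hold the local tensors; the helper name is hypothetical and not part of this commit:

    def _local_args_from_placeholders(parallel_gm: GraphModule) -> list[torch.Tensor]:
        # Hypothetical alternative to passing new_local_args explicitly:
        # collect the local tensors that apply_sharding_to_model stored in
        # the placeholder "val" metas of the parallelized graph.
        return [
            n.meta["val"]
            for n in parallel_gm.graph.nodes
            if n.op == "placeholder"
        ]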
@@ -369,6 +401,7 @@ def apply_placement(self, sharding_placement=None):
             parallel_gm,
             sharded_param_dict,
             sharded_buffer_dict,
+            new_local_args,
         ) = apply_sharding_to_model(
             self.gm,
             sharding_placement,
@@ -392,8 +425,9 @@ def apply_placement(self, sharding_placement=None):
         # parallel_gm, self.params_len, self.buffer_len, self.metadata
         # )
         self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(
+            self.joint_with_descriptors, parallel_gm, new_local_args
+        )
         # NB: so this function takes in the parameters at the beginning
 
         # let's remove those otherwise we can't clean the backward graph properly

autoparallel/apply_sharding.py

Lines changed: 15 additions & 3 deletions
@@ -20,6 +20,7 @@
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import Partial, Replicate, Shard  # noqa
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 from .propagation_rules import TENSOR_FACTORY_OPS
@@ -221,11 +222,22 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     # run with DTensor to apply the collectives given the graph
     interp = ApplyShardingInterpreter(gm, sharding_placement)
 
-    args = [x.to_local() for x in args]
+    local_args = []
+    for placeholder, arg in zip(gm.graph.nodes, args):
+        assert placeholder.meta["val"].shape == arg.shape
+        local_arg = arg.to_local()
+        placeholder.meta["val"] = local_arg
+        # requires_grad is missing from val and local_arg, so take it
+        # from the original tensor_meta
+        requires_grad = placeholder.meta["tensor_meta"].requires_grad
+        placeholder.meta["tensor_meta"] = _extract_tensor_metadata(
+            local_arg.clone().requires_grad_(requires_grad)
+        )
+        local_args.append(local_arg)
 
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp.run)(*args)
+        parallel_gm = make_fx(interp.run)(*local_args)
 
     # Copy descriptors over to new graph
     for n1, n2 in zip(
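
For context on the requires_grad round-trip above: _extract_tensor_metadata records requires_grad from the tensor it is handed, so a local tensor that has lost the flag needs it restored before extraction. A small illustrative check, not part of this commit:

    import torch
    from torch.fx.passes.shape_prop import _extract_tensor_metadata

    t = torch.randn(8, 4)  # local tensor with requires_grad=False
    meta = _extract_tensor_metadata(t.clone().requires_grad_(True))
    assert meta.requires_grad  # the flag is carried into the recorded tensor_meta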
@@ -262,4 +274,4 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
             n, sharding_placement, meta=True
         )
 
-    return parallel_gm, sharded_param_dict, sharded_buffer_dict
+    return parallel_gm, sharded_param_dict, sharded_buffer_dict, local_args

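The commit message leaves "TODO: test this somehow" open. One possible direction, sketched with hypothetical SimpleNamespace stand-ins rather than a real AOTAutograd capture, and assuming autoparallel.api is importable in the test environment:

    import types

    import torch

    from autoparallel.api import update_joint_with_descriptors


    def test_update_joint_with_descriptors_swaps_args():
        # Two original flat args (a parameter and an input) plus one extra
        # local tensor that should be treated as a tangent.
        jwd = types.SimpleNamespace(
            graph_module=None,
            _aot_graph_capture=types.SimpleNamespace(
                graph_module=None, updated_flat_args=None
            ),
            _aot_state=types.SimpleNamespace(
                flat_args=[torch.nn.Parameter(torch.randn(4, 4)), torch.randn(4, 4)],
                fw_metadata=types.SimpleNamespace(traced_tangents=[]),
            ),
        )
        new_local_args = [torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)]
        fake_gm = object()  # stands in for the parallelized GraphModule

        update_joint_with_descriptors(jwd, fake_gm, new_local_args)

        assert jwd.graph_module is fake_gm
        assert jwd._aot_graph_capture.graph_module is fake_gm
        # Parameter-ness is preserved and tensors are swapped for the local ones.
        assert isinstance(jwd._aot_state.flat_args[0], torch.nn.Parameter)
        assert jwd._aot_state.flat_args[1] is new_local_args[1]
        # Everything beyond the original flat args is treated as tangents.
        assert jwd._aot_state.fw_metadata.traced_tangents[0] is new_local_args[2]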