 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
 )
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState

 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
 from .utils import _get_device_from_mesh


+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    parallel_gm: GraphModule,
+    new_local_args: list[torch.Tensor],
+) -> None:
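+    """Point an existing JointWithDescriptors at the parallelized graph module
+    and at the new local (sharded) inputs that feed it."""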
+    # TODO: should we upstream a util like this?
+    # Should we pass 'new_local_args' in here, or should we just rely on the
+    # parallel_gm having its placeholder metas updated and extract the
+    # new_local_args from there?
+    joint_with_descriptors.graph_module = parallel_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = parallel_gm
+
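+    # Rebuild the flat argument list from the new local (sharded) tensors,
+    # preserving Parameter-ness where the original entry was a Parameter.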
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
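+    # Entries of new_local_args past the original flat_args are the tangents.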
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
@@ -369,6 +401,7 @@ def apply_placement(self, sharding_placement=None):
             parallel_gm,
             sharded_param_dict,
             sharded_buffer_dict,
+            new_local_args,
         ) = apply_sharding_to_model(
             self.gm,
             sharding_placement,
@@ -392,8 +425,9 @@ def apply_placement(self, sharding_placement=None):
         # parallel_gm, self.params_len, self.buffer_len, self.metadata
         # )
         self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(
+            self.joint_with_descriptors, parallel_gm, new_local_args
+        )
         # NB: so this function takes in the parameters at the beginning

         # let's remove those otherwise we can't clean the backward graph properly