 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
+    boxed_nop_preserve_node_meta,
 )
+from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table
 from torch._inductor.fx_passes.joint_graph import joint_graph_passes
 from torch._inductor.fx_passes.post_grad import remove_assert_ops
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState

 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
 from .utils import _get_device_from_mesh


+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    updated_gm: GraphModule,
+) -> None:
+    """
+    Assuming updated_gm has been transformed since it was captured
+    (e.g. by parallelizing it), this util updates the joint_with_descriptors
+    struct to reference the new gm, and updates any copies of tensor meta/shape
+    stored in joint_with_descriptors relating to input arguments, which may have
+    changed shape since the initial trace.
+    """
+    # TODO: should we upstream a util like this?
+    placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
+    new_local_args = [n.meta["val"] for n in placeholders]
+    joint_with_descriptors.graph_module = updated_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = updated_gm
+
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
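# Editor's sketch (not part of this patch): rough idea of how the util above is
# meant to be used once a captured joint graph has been transformed.
# `parallelize_graph` is a hypothetical stand-in for whatever pass mutates the
# graph and its placeholder shapes.
#
#   joint = aot_export_joint_with_descriptors(stack, ep.module(), inputs, ...)
#   parallel_gm = parallelize_graph(joint.graph_module)  # may change input shapes
#   update_joint_with_descriptors(joint, parallel_gm)    # re-sync gm, flat_args, tangents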
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
@@ -163,6 +202,7 @@ def __init__(
         input_fn,
         mesh: DeviceMesh,
         mp_policy: Optional[MixedPrecisionPolicy] = None,
+        compile: bool = False,
     ):
         self.stack = ExitStack()
         self.fake_mode = (
@@ -187,6 +227,7 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
+        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
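# Editor's note (not part of this patch): with compile=True the forward/backward
# graphs are lowered through Inductor's compile_fx_inner; the default
# boxed_nop_preserve_node_meta leaves the captured FX graphs uncompiled while
# preserving node metadata for later passes.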

         # NB: rest of the construction happens in __enter__

@@ -196,6 +237,10 @@ def __enter__(self):
         assert self.active is False

         self.build_model_graph()
+        self.old_inductor_comprehensive_padding = (
+            torch._inductor.config.comprehensive_padding
+        )
+        torch._inductor.config.comprehensive_padding = False

         rescale_grad_comm_cost_for_mp = 1.0
         if self.mp_policy is not None:
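# Editor's note (assumption, not stated in the patch): Inductor's
# comprehensive_padding option pads tensor strides for alignment; it appears to
# be disabled here so layouts stay stable while this context is active, with the
# saved value restored in __exit__ below.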
@@ -224,6 +269,9 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._inductor.config.comprehensive_padding = (
+            self.old_inductor_comprehensive_padding
+        )
         self.active = None
         return self.stack.__exit__(exc_type, exc_val, exc_tb)

@@ -252,7 +300,12 @@ def build_model_graph(self):
         with set_dtype_cast(True):
             ep = torch.export.export(self.model, inputs)
         self.joint_with_descriptors = aot_export_joint_with_descriptors(
-            self.stack, ep.module(), inputs, decompositions=decomp_table
+            self.stack,
+            ep.module(),
+            inputs,
+            decompositions=decomp_table,
+            fw_compiler=self.compiler_fn,
+            bw_compiler=self.compiler_fn,
         )
         gm = self.joint_with_descriptors.graph_module

@@ -392,8 +445,7 @@ def apply_placement(self, sharding_placement=None):
         # parallel_gm, self.params_len, self.buffer_len, self.metadata
         # )
         self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm)
         # NB: so this function takes in the parameters at the beginning

         # let's remove those otherwise we can't clean the backward graph properly
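# Editor's sketch (not part of this patch): how the new `compile` flag might be
# used. The class name and argument order are not fully visible in these hunks,
# so `AutoParallel`, `model`, `input_fn`, and `mesh` are assumed placeholders.
#
#   with AutoParallel(model, input_fn, mesh, mp_policy=None, compile=True) as ap:
#       ...  # forward/backward graphs are now lowered via compile_fx_inner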