Commit 0cd63db

Use inductor decompositions in apply_sharding_to_model (#68)
Inductor requires that a specific set of decompositions be used when compiling a model. We are currently using a slightly different set of decompositions than Inductor, so let's convert the final parallelized graph to use the same decompositions as Inductor.
1 parent 1436bbc commit 0cd63db
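
For context, the sketch below (not part of this commit; the toy function and tensors are made up) shows what Inductor's decomposition table is and the most direct way to trace through it: select_decomp_table() returns a mapping from ATen op overloads to Python decomposition functions, and make_fx can consume it via its decomposition_table argument. The commit instead re-runs the already parallelized graph through interpreters, since here the decompositions are applied after sharding.

# Sketch only: trace a toy function with Inductor's decomposition table.
import torch
from torch._inductor.decomposition import select_decomp_table
from torch.fx.experimental.proxy_tensor import make_fx

decomp_table = select_decomp_table()  # {OpOverload: decomposition_fn, ...}

def f(x, w, b):
    # composite ops such as addmm/silu may be rewritten into simpler aten ops
    return torch.nn.functional.silu(torch.addmm(b, x, w))

x, w, b = torch.randn(8, 16), torch.randn(16, 4), torch.randn(4)
gm = make_fx(f, decomposition_table=decomp_table)(x, w, b)
gm.graph.print_tabular()  # the traced graph now contains the decomposed ops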

File tree

1 file changed (+58, -3 lines)


autoparallel/apply_sharding.py

Lines changed: 58 additions & 3 deletions
@@ -14,6 +14,7 @@
     get_named_buffer_nodes,
     get_named_param_nodes,
 )
+from torch._inductor.decomposition import select_decomp_table
 from torch._subclasses.fake_tensor import FakeTensor, unset_fake_temporarily
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -49,9 +50,12 @@ def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
 
 
 class ApplyShardingInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, sharding_placement):
+    def __init__(self, module, sharding_placement, decomp_table=None):
         super().__init__(module, garbage_collect_values=True, graph=None)
         self.sharding_placement = sharding_placement
+        if decomp_table is None:
+            decomp_table = {}
+        self.decomp_table = decomp_table
 
     def run_node(self, n):
         self._curr_node = n
@@ -161,11 +165,36 @@ def call_function(self, target, args, kwargs):
         # TODO: see if we can remove this contiguous properly
         new_args[0] = new_args[0].contiguous()
 
+        if target in self.decomp_table:
+            new_target = self.decomp_table[target]
+            out = super().call_function(new_target, tuple(new_args), kwargs)
+            # NOTE: is there a canonical way of handling this?
+            if out is not NotImplemented:
+                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
+                return out
         out = super().call_function(target, tuple(new_args), kwargs)
         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
         return out
 
 
+class ApplyDecompInterpreter(torch.fx.Interpreter):
+    def __init__(self, module, decomp_table=None):
+        super().__init__(module, garbage_collect_values=True, graph=None)
+        if decomp_table is None:
+            decomp_table = {}
+        self.decomp_table = decomp_table
+
+    def call_function(self, target, args, kwargs):
+        if target in self.decomp_table:
+            new_target = self.decomp_table[target]
+            out = super().call_function(new_target, args, kwargs)
+            # NOTE: is there a canonical way of handling this?
+            if out is not NotImplemented:
+                return out
+        out = super().call_function(target, args, kwargs)
+        return out
+
+
 def shard_node_given_placements(node, sharding_placement, *, meta: bool):
     # TODO: not sure if we actually guarantee sharding_placement has ever
     # input node lol
@@ -215,17 +244,43 @@ def rename_placeholder_node(
         fx_g.graph.erase_node(node)
 
 
+def _cleanup_graph(gm):
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+    prev = torch._inductor.config.pattern_matcher
+    torch._inductor.config.pattern_matcher = False
+    prev_pass = torch._inductor.config.joint_custom_post_pass
+    torch._inductor.config.joint_custom_post_pass = None
+    from torch._inductor.fx_passes.joint_graph import joint_graph_passes
+
+    try:
+        # TODO: Double check if this is what we want to do
+        gm = joint_graph_passes(gm)
+    finally:
+        torch._inductor.config.pattern_matcher = prev
+        torch._inductor.config.joint_custom_post_pass = prev_pass
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+
+
 def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     args = shard_nodes_given_placements(gm, sharding_placement)
 
+    decomp_table = select_decomp_table()
     # run with DTensor to apply the collectives given the graph
-    interp = ApplyShardingInterpreter(gm, sharding_placement)
+    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
 
     args = [x.to_local() for x in args]
 
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp.run)(*args)
+        parallel_gm0 = make_fx(interp.run)(*args)
+
+    _cleanup_graph(parallel_gm0)
+    interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
+    with fx_traceback.preserve_node_meta():
+        parallel_gm = make_fx(interp2.run)(*args)
+    _cleanup_graph(parallel_gm)
 
     # Copy descriptors over to new graph
     for n1, n2 in zip(
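
For reference, the retracing pattern introduced above (run the traced module under a torch.fx.Interpreter that redirects call_function targets through the decomposition table, then capture the result with make_fx) can be exercised standalone roughly as follows. This is a simplified sketch, not the project's code: DecompInterpreter, the toy function f, and the example tensors are illustrative only.

# Standalone sketch of the retracing pattern (toy example, not the project's code).
import torch
from torch._inductor.decomposition import select_decomp_table
from torch.fx.experimental.proxy_tensor import make_fx


class DecompInterpreter(torch.fx.Interpreter):
    def __init__(self, module, decomp_table):
        super().__init__(module, garbage_collect_values=True)
        self.decomp_table = decomp_table

    def call_function(self, target, args, kwargs):
        # If Inductor has a decomposition for this op, run it instead; a
        # decomposition may bail out by returning NotImplemented.
        if target in self.decomp_table:
            out = super().call_function(self.decomp_table[target], args, kwargs)
            if out is not NotImplemented:
                return out
        return super().call_function(target, args, kwargs)


def f(x, w):
    return torch.nn.functional.silu(x @ w)


x, w = torch.randn(8, 16), torch.randn(16, 4)
gm = make_fx(f)(x, w)                      # first trace, default decompositions
interp = DecompInterpreter(gm, select_decomp_table())
decomposed_gm = make_fx(interp.run)(x, w)  # second trace with Inductor decomps

In the actual change, _cleanup_graph additionally runs Inductor's joint_graph_passes (with the pattern matcher and custom post pass disabled) plus dead-code elimination before and after the second trace.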
