14 | 14 |     get_named_buffer_nodes,
15 | 15 |     get_named_param_nodes,
16 | 16 | )
   | 17 | +from torch._inductor.decomposition import select_decomp_table
17 | 18 | from torch._subclasses.fake_tensor import FakeTensor, unset_fake_temporarily
18 | 19 | from torch.distributed.tensor import DTensor
19 | 20 | from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -49,9 +50,12 @@ def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
49 | 50 |
50 | 51 |
51 | 52 | class ApplyShardingInterpreter(torch.fx.Interpreter):
52 |    | -    def __init__(self, module, sharding_placement):
   | 53 | +    def __init__(self, module, sharding_placement, decomp_table=None):
53 | 54 |         super().__init__(module, garbage_collect_values=True, graph=None)
54 | 55 |         self.sharding_placement = sharding_placement
   | 56 | +        if decomp_table is None:
   | 57 | +            decomp_table = {}
   | 58 | +        self.decomp_table = decomp_table
55 | 59 |
56 | 60 |     def run_node(self, n):
57 | 61 |         self._curr_node = n
@@ -161,11 +165,36 @@ def call_function(self, target, args, kwargs):
161 | 165 |         # TODO: see if we can remove this contiguous properly
162 | 166 |         new_args[0] = new_args[0].contiguous()
163 | 167 |
    | 168 | +        if target in self.decomp_table:
    | 169 | +            new_target = self.decomp_table[target]
    | 170 | +            out = super().call_function(new_target, tuple(new_args), kwargs)
    | 171 | +            # NOTE: is there a canonical way of handling this?
    | 172 | +            if out is not NotImplemented:
    | 173 | +                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
    | 174 | +                return out
164 | 175 |         out = super().call_function(target, tuple(new_args), kwargs)
165 | 176 |         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
166 | 177 |         return out
167 | 178 |
168 | 179 |
    | 180 | +class ApplyDecompInterpreter(torch.fx.Interpreter):
    | 181 | +    def __init__(self, module, decomp_table=None):
    | 182 | +        super().__init__(module, garbage_collect_values=True, graph=None)
    | 183 | +        if decomp_table is None:
    | 184 | +            decomp_table = {}
    | 185 | +        self.decomp_table = decomp_table
    | 186 | +
    | 187 | +    def call_function(self, target, args, kwargs):
    | 188 | +        if target in self.decomp_table:
    | 189 | +            new_target = self.decomp_table[target]
    | 190 | +            out = super().call_function(new_target, args, kwargs)
    | 191 | +            # NOTE: is there a canonical way of handling this?
    | 192 | +            if out is not NotImplemented:
    | 193 | +                return out
    | 194 | +        out = super().call_function(target, args, kwargs)
    | 195 | +        return out
    | 196 | +
    | 197 | +
169 | 198 | def shard_node_given_placements(node, sharding_placement, *, meta: bool):
170 | 199 |     # TODO: not sure if we actually guarantee sharding_placement has every
171 | 200 |     # input node lol
@@ -215,17 +244,43 @@ def rename_placeholder_node(
215 | 244 |     fx_g.graph.erase_node(node)
216 | 245 |
217 | 246 |
    | 247 | +def _cleanup_graph(gm):
    | 248 | +    gm.graph.eliminate_dead_code()
    | 249 | +    gm.recompile()
    | 250 | +    prev = torch._inductor.config.pattern_matcher
    | 251 | +    torch._inductor.config.pattern_matcher = False
    | 252 | +    prev_pass = torch._inductor.config.joint_custom_post_pass
    | 253 | +    torch._inductor.config.joint_custom_post_pass = None
    | 254 | +    from torch._inductor.fx_passes.joint_graph import joint_graph_passes
    | 255 | +
    | 256 | +    try:
    | 257 | +        # TODO: Double check if this is what we want to do
    | 258 | +        gm = joint_graph_passes(gm)
    | 259 | +    finally:
    | 260 | +        torch._inductor.config.pattern_matcher = prev
    | 261 | +        torch._inductor.config.joint_custom_post_pass = prev_pass
    | 262 | +    gm.graph.eliminate_dead_code()
    | 263 | +    gm.recompile()
    | 264 | +
    | 265 | +
218 | 266 | def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
219 | 267 |     args = shard_nodes_given_placements(gm, sharding_placement)
220 | 268 |
    | 269 | +    decomp_table = select_decomp_table()
221 | 270 |     # run with DTensor to apply the collectives given the graph
222 |    | -    interp = ApplyShardingInterpreter(gm, sharding_placement)
    | 271 | +    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
223 | 272 |
224 | 273 |     args = [x.to_local() for x in args]
225 | 274 |
226 | 275 |     # TODO: make_fx here is suspicious in case of dynamic shapes
227 | 276 |     with fx_traceback.preserve_node_meta():
228 |    | -        parallel_gm = make_fx(interp.run)(*args)
    | 277 | +        parallel_gm0 = make_fx(interp.run)(*args)
    | 278 | +
    | 279 | +    _cleanup_graph(parallel_gm0)
    | 280 | +    interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
    | 281 | +    with fx_traceback.preserve_node_meta():
    | 282 | +        parallel_gm = make_fx(interp2.run)(*args)
    | 283 | +    _cleanup_graph(parallel_gm)
229 | 284 |
230 | 285 |     # Copy descriptors over to new graph
231 | 286 |     for n1, n2 in zip(
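
For context on the new decomposition dispatch: the pattern is an `fx.Interpreter` whose `call_function` swaps the target for its entry in the Inductor decomp table and falls back to the original op when the decomposition declines. Below is a minimal standalone sketch of that idea; the `DecompInterpreter` name and the toy `f` are made up here, and which ops actually get rewritten depends on what `select_decomp_table()` contains in your PyTorch build.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.decomposition import select_decomp_table


class DecompInterpreter(torch.fx.Interpreter):
    # Hypothetical stand-in mirroring ApplyDecompInterpreter above.
    def __init__(self, module, decomp_table=None):
        super().__init__(module)
        self.decomp_table = decomp_table if decomp_table is not None else {}

    def call_function(self, target, args, kwargs):
        if target in self.decomp_table:
            out = super().call_function(self.decomp_table[target], args, kwargs)
            # A decomposition may return NotImplemented to decline the op;
            # fall through to the original target in that case.
            if out is not NotImplemented:
                return out
        return super().call_function(target, args, kwargs)


def f(x):
    return torch.nn.functional.gelu(x) + 1


x = torch.randn(8)
gm = make_fx(f)(x)
# Re-trace through the interpreter so decomposed ops end up in a new graph.
decomposed_gm = make_fx(DecompInterpreter(gm, select_decomp_table()).run)(x)
print(decomposed_gm.graph)
```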
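
The new `_cleanup_graph` helper saves and restores two Inductor config flags by hand around `joint_graph_passes`. Here is a sketch of the same step that uses `torch._inductor.config.patch` for the temporary override instead; this is an alternative spelling rather than what the PR itself does, and it assumes both config entries are patchable in the installed PyTorch (the diff reads and writes them, so they should exist).

```python
import torch
from torch._inductor import config as inductor_config
from torch._inductor.fx_passes.joint_graph import joint_graph_passes


def cleanup_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Same dead-code / recompile bracketing as _cleanup_graph in the diff.
    gm.graph.eliminate_dead_code()
    gm.recompile()
    # Temporarily disable the pattern matcher and any custom post pass so only
    # the generic joint-graph normalizations run on this graph.
    with inductor_config.patch(pattern_matcher=False, joint_custom_post_pass=None):
        gm = joint_graph_passes(gm)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
```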