From acaaf9cfaeae7976479f60b212408d78454d16a4 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 4 Aug 2025 16:25:53 -0700
Subject: [PATCH] Move compile into AutoParallel API

Update node metas after parallelization
---
 autoparallel/api.py              | 60 +++++++++++++++++++++++++++++---
 autoparallel/apply_sharding.py   |  7 ++--
 examples/example_autoparallel.py |  6 ++--
 examples/example_llama3.py       |  2 +-
 4 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index b24cb611..47baad7d 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -7,13 +7,16 @@
 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
+    boxed_nop_preserve_node_meta,
 )
+from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table
 from torch._inductor.fx_passes.joint_graph import joint_graph_passes
 from torch._inductor.fx_passes.post_grad import remove_assert_ops
@@ -23,6 +26,8 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState
 
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +36,40 @@
 from .utils import _get_device_from_mesh
 
 
+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    updated_gm: GraphModule,
+) -> None:
+    """
+    Given an updated_gm that has been transformed since it was captured
+    (e.g. by parallelizing it), point joint_with_descriptors at the new
+    graph module and refresh the copies of tensor meta/shape information
+    it holds for the input arguments, whose shapes may have changed since
+    the initial trace.
+    """
+    # TODO: should we upstream a util like this?
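+    # The placeholder nodes of the transformed graph carry the post-transform
+    # (e.g. local/sharded) fake tensors in meta["val"]; use them to rebuild the
+    # flat args and tangent metadata that AOTAutograd recorded at capture time.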
+    placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
+    new_local_args = [n.meta["val"] for n in placeholders]
+    joint_with_descriptors.graph_module = updated_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = updated_gm
+
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
@@ -163,6 +202,7 @@ def __init__(
         input_fn,
         mesh: DeviceMesh,
         mp_policy: Optional[MixedPrecisionPolicy] = None,
+        compile: bool = False,
     ):
         self.stack = ExitStack()
         self.fake_mode = (
@@ -187,6 +227,7 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
+        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
 
         # NB: rest of the construction happens in __enter__
 
@@ -196,6 +237,10 @@ def __enter__(self):
         assert self.active is False
 
         self.build_model_graph()
+        self.old_inductor_comprehensive_padding = (
+            torch._inductor.config.comprehensive_padding
+        )
+        torch._inductor.config.comprehensive_padding = False
 
         rescale_grad_comm_cost_for_mp = 1.0
         if self.mp_policy is not None:
@@ -224,6 +269,9 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._inductor.config.comprehensive_padding = (
+            self.old_inductor_comprehensive_padding
+        )
         self.active = None
         return self.stack.__exit__(exc_type, exc_val, exc_tb)
 
@@ -252,7 +300,12 @@ def build_model_graph(self):
         with set_dtype_cast(True):
             ep = torch.export.export(self.model, inputs)
         self.joint_with_descriptors = aot_export_joint_with_descriptors(
-            self.stack, ep.module(), inputs, decompositions=decomp_table
+            self.stack,
+            ep.module(),
+            inputs,
+            decompositions=decomp_table,
+            fw_compiler=self.compiler_fn,
+            bw_compiler=self.compiler_fn,
         )
 
         gm = self.joint_with_descriptors.graph_module
@@ -392,8 +445,7 @@ def apply_placement(self, sharding_placement=None):
         #     parallel_gm, self.params_len, self.buffer_len, self.metadata
         # )
         self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm)
 
         # NB: so this function takes in the parameters at the beginning
         # let's remove those otherwise we can't clean the backward graph properly
diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py
index 2e12c1d3..151f4bff 100644
--- a/autoparallel/apply_sharding.py
+++ b/autoparallel/apply_sharding.py
@@ -265,21 +265,20 @@ def _cleanup_graph(gm):
 
 def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     args = shard_nodes_given_placements(gm, sharding_placement)
+    local_args = [arg.to_local() for arg in args]
 
     decomp_table = select_decomp_table()
 
     # run with DTensor to apply the collectives given the graph
     interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
-    args = [x.to_local() for x in args]
-
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
-        parallel_gm0 = make_fx(interp.run)(*args)
+        parallel_gm0 = make_fx(interp.run)(*local_args)
     _cleanup_graph(parallel_gm0)
 
     interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp2.run)(*args)
+        parallel_gm = make_fx(interp2.run)(*local_args)
     _cleanup_graph(parallel_gm)
 
     # Copy descriptors over to new graph
diff --git a/examples/example_autoparallel.py b/examples/example_autoparallel.py
index ff9eff2a..5c56ce5d 100644
--- a/examples/example_autoparallel.py
+++ b/examples/example_autoparallel.py
@@ -93,6 +93,7 @@ def forward(self, x):
 
 
 def input_fn():
+    print(f"global input shape: {(bs, seq_len, dim1)}")
     return torch.rand(bs, seq_len, dim1, device="cuda")
 
 
@@ -104,10 +105,9 @@ def input_fn():
 # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
 # mp_policy = None
 
-with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
+with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
     assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
-
     autop.add_parameter_memory_constraint(low=None, high=None)
 
     x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
@@ -123,8 +123,6 @@ def input_fn():
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()
 
-parallel_mod.compile(fullgraph=True)
-
 # now let's run it
 x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
 out = parallel_mod(*x)
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index 36e64795..47a2b1e8 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -605,7 +605,7 @@ def input_fn():
 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
 
 # parallelize the model
-with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
+with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     autop.add_parameter_memory_constraint(low=None, high=None)
 
     x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
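
A minimal sketch of the resulting user-facing flow, pieced together from the example diffs above (names such as model, input_fn, mesh, mp_policy, bs, seq_len and dim1 come from examples/example_autoparallel.py; the placement steps elided with "..." are unchanged by this patch, so this is not a self-contained script):

# Compilation is now requested up front via the AutoParallel constructor.
with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
    autop.add_parameter_memory_constraint(low=None, high=None)
    x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
    ...  # choose and apply a sharding placement as before, yielding parallel_mod

parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()

# No trailing parallel_mod.compile(fullgraph=True) anymore: with compile=True,
# compile_fx_inner is passed as fw_compiler/bw_compiler to
# aot_export_joint_with_descriptors, so Inductor compiles the graphs inside the
# AutoParallel pipeline; compile=False keeps boxed_nop_preserve_node_meta.
x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
out = parallel_mod(*x)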