diff --git a/autoparallel/api.py b/autoparallel/api.py
index b24cb611..304c2e20 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -7,10 +7,11 @@
 import itertools
 from contextlib import ExitStack
 from types import MethodType
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch._functorch.aot_autograd import (
+    JointWithDescriptors,
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
 )
@@ -23,6 +24,8 @@
 from torch.distributed.tensor import DeviceMesh
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
+from torch.fx import GraphModule
+from torch.fx.experimental._backward_state import BackwardState
 
 from .apply_sharding import apply_sharding_to_model
 from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +34,35 @@
 from .utils import _get_device_from_mesh
 
 
+def update_joint_with_descriptors(
+    joint_with_descriptors: JointWithDescriptors,
+    parallel_gm: GraphModule,
+    new_local_args: list[torch.Tensor],
+) -> None:
+    # TODO: should we upstream a util like this?
+    # Should we pass 'new_local_args' in here?
+    # Or should we just rely on the parallel_gm to have its placeholder metas updated and
+    # extract the new_local_args from there?
+    joint_with_descriptors.graph_module = parallel_gm
+    joint_with_descriptors._aot_graph_capture.graph_module = parallel_gm
+
+    new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
+    for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
+        if isinstance(orig, torch.nn.Parameter):
+            new_flat_args.append(torch.nn.Parameter(new))
+        else:
+            new_flat_args.append(new)
+
+    tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
+    new_local_tangents = new_local_args[tangent_idx:]
+    joint_with_descriptors._aot_graph_capture.updated_flat_args = (
+        new_flat_args,
+        new_local_tangents,
+    )
+    joint_with_descriptors._aot_state.flat_args = new_flat_args
+    joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
+
+
 def _add_alias(gm):
     """
     Helper function to add alias nodes to every node in the graph
@@ -369,6 +401,7 @@ def apply_placement(self, sharding_placement=None):
             parallel_gm,
             sharded_param_dict,
             sharded_buffer_dict,
+            new_local_args,
         ) = apply_sharding_to_model(
             self.gm,
             sharding_placement,
@@ -392,8 +425,9 @@
         #     parallel_gm, self.params_len, self.buffer_len, self.metadata
        # )
        self.parallel_gm = parallel_gm
-        self.joint_with_descriptors.graph_module = parallel_gm
-
+        update_joint_with_descriptors(
+            self.joint_with_descriptors, parallel_gm, new_local_args
+        )
        # NB: so this function takes in the parameters at the beginning
        # let's remove those otherwise we can't clean the backward graph properly
 
diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py
index 19b35d00..d8bab0a5 100644
--- a/autoparallel/apply_sharding.py
+++ b/autoparallel/apply_sharding.py
@@ -20,6 +20,7 @@
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import Partial, Replicate, Shard  # noqa
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 from .propagation_rules import TENSOR_FACTORY_OPS
@@ -221,11 +222,22 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     # run with DTensor to apply the collectives given the graph
     interp = ApplyShardingInterpreter(gm, sharding_placement)
 
-    args = [x.to_local() for x in args]
+    local_args = []
+    for placeholder, arg in zip(gm.graph.nodes, args):
+        assert placeholder.meta["val"].shape == arg.shape
+        local_arg = arg.to_local()
+        placeholder.meta["val"] = local_arg
+        # requires_grad is missing from val and local_arg, take it
+        # from the original tensor_meta
+        requires_grad = placeholder.meta["tensor_meta"].requires_grad
+        placeholder.meta["tensor_meta"] = _extract_tensor_metadata(
+            local_arg.clone().requires_grad_(requires_grad)
+        )
+        local_args.append(local_arg)
 
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp.run)(*args)
+        parallel_gm = make_fx(interp.run)(*local_args)
 
     # Copy descriptors over to new graph
     for n1, n2 in zip(
@@ -262,4 +274,4 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
             n, sharding_placement, meta=True
         )
 
-    return parallel_gm, sharded_param_dict, sharded_buffer_dict
+    return parallel_gm, sharded_param_dict, sharded_buffer_dict, local_args
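
Illustrative sketch (not part of the patch above): the apply_sharding.py hunk swaps each DTensor input for its local shard and refreshes the placeholder metadata so downstream passes see per-rank shapes. The snippet below reproduces that metadata-refresh pattern on a toy fx graph; toy_fn, new_args, and their shapes are hypothetical stand-ins, and real inputs would be DTensors converted with .to_local() first.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata


def toy_fn(x, y):  # hypothetical stand-in for the captured joint graph
    return x + y


gm = make_fx(toy_fn)(torch.randn(4, 4), torch.randn(4, 4))

# Pretend the per-rank "local" inputs have a different (sharded) shape than the
# global tensors that were traced, and push the new values into the placeholder
# metadata, mirroring what the patch does after calling .to_local() on each arg.
new_args = [torch.randn(2, 4, requires_grad=True), torch.randn(2, 4)]
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
for node, arg in zip(placeholders, new_args):
    node.meta["val"] = arg
    # requires_grad is carried over explicitly, as in the patch, since the
    # plain value alone may not reflect it.
    node.meta["tensor_meta"] = _extract_tensor_metadata(
        arg.detach().clone().requires_grad_(arg.requires_grad)
    )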