autoparallel/api.py (37 additions, 3 deletions)
@@ -7,10 +7,11 @@
import itertools
from contextlib import ExitStack
from types import MethodType
from typing import Optional
from typing import Optional, Union

import torch
from torch._functorch.aot_autograd import (
JointWithDescriptors,
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
)
@@ -23,6 +24,8 @@
from torch.distributed.tensor import DeviceMesh
from torch.export._unlift import _assign_attr
from torch.export.unflatten import _AttrKind
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState

from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +34,35 @@
from .utils import _get_device_from_mesh


def update_joint_with_descriptors(
joint_with_descriptors: JointWithDescriptors,
parallel_gm: GraphModule,
new_local_args: list[torch.Tensor],
) -> None:
# TODO: should we upstream a util like this?
# Should we pass 'new_local_args' in here?
# or should we just rely on the parallel_gm to have its placeholder metas updated and
# extract the new_local_args from there?
joint_with_descriptors.graph_module = parallel_gm
joint_with_descriptors._aot_graph_capture.graph_module = parallel_gm

new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
if isinstance(orig, torch.nn.Parameter):
new_flat_args.append(torch.nn.Parameter(new))
else:
new_flat_args.append(new)

tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
new_local_tangents = new_local_args[tangent_idx:]
joint_with_descriptors._aot_graph_capture.updated_flat_args = (
new_flat_args,
new_local_tangents,
)
joint_with_descriptors._aot_state.flat_args = new_flat_args
joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents


def _add_alias(gm):
"""
Helper function to add alias nodes to every node in the graph
@@ -369,6 +401,7 @@ def apply_placement(self, sharding_placement=None):
parallel_gm,
sharded_param_dict,
sharded_buffer_dict,
new_local_args,
) = apply_sharding_to_model(
self.gm,
sharding_placement,
@@ -392,8 +425,9 @@
# parallel_gm, self.params_len, self.buffer_len, self.metadata
# )
self.parallel_gm = parallel_gm
self.joint_with_descriptors.graph_module = parallel_gm

update_joint_with_descriptors(
self.joint_with_descriptors, parallel_gm, new_local_args
)
# NB: so this function takes in the parameters at the beginning

# let's remove those otherwise we can't clean the backward graph properly
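The TODO at the top of update_joint_with_descriptors asks whether new_local_args should be passed in explicitly or recovered from the updated placeholder metas of parallel_gm. A minimal sketch of the second option, assuming the placeholder nodes already carry the local example tensors in meta["val"] after the sharding pass in the next file; the helper name local_args_from_placeholders is illustrative and not part of this PR:

import torch
from torch.fx import GraphModule


def local_args_from_placeholders(parallel_gm: GraphModule) -> list[torch.Tensor]:
    # Placeholder nodes come first in an fx graph; make_fx records their
    # example (now local) tensors in node.meta["val"].
    return [
        node.meta["val"]
        for node in parallel_gm.graph.nodes
        if node.op == "placeholder"
    ]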
autoparallel/apply_sharding.py (15 additions, 3 deletions)
@@ -20,6 +20,7 @@
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard # noqa
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import tree_flatten, tree_map_only

from .propagation_rules import TENSOR_FACTORY_OPS
@@ -221,11 +222,22 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
# run with DTensor to apply the collectives given the graph
interp = ApplyShardingInterpreter(gm, sharding_placement)

args = [x.to_local() for x in args]
local_args = []
for placeholder, arg in zip(gm.graph.nodes, args):
assert placeholder.meta["val"].shape == arg.shape
local_arg = arg.to_local()
placeholder.meta["val"] = local_arg
# requires_grad is missing from val and local_arg; take it
# from the original tensor_meta
requires_grad = placeholder.meta["tensor_meta"].requires_grad
placeholder.meta["tensor_meta"] = _extract_tensor_metadata(
local_arg.clone().requires_grad_(requires_grad)
)
local_args.append(local_arg)
Comment on lines +225 to +236
Contributor:

I think this is wrong, and we would need to implement this after we construct the parameters, otherwise we end up sharding the model twice.

Contributor Author:

You are correct. I confirmed that after running make_fx on the sharding interpreter, the new parallel_gm0 has its metas updated. I am not sure why I (thought I) had to do this update in hack_aot.
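A quick spot-check along the lines of what the author describes, assuming parallel_gm is the module returned by make_fx(interp.run)(*local_args) below; this is illustrative only and not part of the PR:

for node in parallel_gm.graph.nodes:
    if node.op != "placeholder":
        continue
    example = node.meta["val"]  # example tensor recorded during make_fx tracing
    print(node.name, tuple(example.shape), example.requires_grad)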


# TODO: make_fx here is suspicious in case of dynamic shapes
with fx_traceback.preserve_node_meta():
parallel_gm = make_fx(interp.run)(*args)
parallel_gm = make_fx(interp.run)(*local_args)

# Copy descriptors over to new graph
for n1, n2 in zip(
@@ -262,4 +274,4 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
n, sharding_placement, meta=True
)

return parallel_gm, sharded_param_dict, sharded_buffer_dict
return parallel_gm, sharded_param_dict, sharded_buffer_dict, local_args
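
Taken together, the two files change the contract: apply_sharding_to_model now also returns the local (post-sharding) example arguments, and update_joint_with_descriptors re-points the AOT joint-graph bookkeeping at the new graph and those arguments. A hedged sketch of the resulting call sequence, assuming gm, sharding_placement, params_spec, buffers_spec, and joint_with_descriptors exist as they do inside the apply_placement method shown above:

parallel_gm, sharded_param_dict, sharded_buffer_dict, local_args = apply_sharding_to_model(
    gm, sharding_placement, params_spec, buffers_spec
)
# Point the captured joint graph and its example inputs at the sharded,
# local-tensor versions in one place.
update_joint_with_descriptors(joint_with_descriptors, parallel_gm, local_args)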