60 changes: 56 additions & 4 deletions autoparallel/api.py
@@ -7,13 +7,16 @@
import itertools
from contextlib import ExitStack
from types import MethodType
from typing import Optional
from typing import Optional, Union

import torch
from torch._functorch.aot_autograd import (
JointWithDescriptors,
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
boxed_nop_preserve_node_meta,
)
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table
from torch._inductor.fx_passes.joint_graph import joint_graph_passes
from torch._inductor.fx_passes.post_grad import remove_assert_ops
@@ -23,6 +26,8 @@
from torch.distributed.tensor import DeviceMesh
from torch.export._unlift import _assign_attr
from torch.export.unflatten import _AttrKind
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState

from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
@@ -31,6 +36,40 @@
from .utils import _get_device_from_mesh


def update_joint_with_descriptors(
joint_with_descriptors: JointWithDescriptors,
updated_gm: GraphModule,
) -> None:
"""
Assuming updated_gm has been transformed since it was captured
(e.g. by parallelizing it), this util updates the joint_with_descriptors
struct to reference the new gm and refreshes any copies of input-argument
tensor meta/shapes stored in joint_with_descriptors, which may have
changed shape since the initial trace.
"""
# TODO: should we upstream a util like this?
placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
new_local_args = [n.meta["val"] for n in placeholders]
joint_with_descriptors.graph_module = updated_gm
joint_with_descriptors._aot_graph_capture.graph_module = updated_gm

new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = []
for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
if isinstance(orig, torch.nn.Parameter):
new_flat_args.append(torch.nn.Parameter(new))
else:
new_flat_args.append(new)

tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
new_local_tangents = new_local_args[tangent_idx:]
joint_with_descriptors._aot_graph_capture.updated_flat_args = (
new_flat_args,
new_local_tangents,
)
joint_with_descriptors._aot_state.flat_args = new_flat_args
joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents


def _add_alias(gm):
"""
Helper function to add alias nodes to every node in the graph
@@ -163,6 +202,7 @@ def __init__(
input_fn,
mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy] = None,
compile: bool = False,
):
self.stack = ExitStack()
self.fake_mode = (
@@ -187,6 +227,7 @@ def __init__(
self.model = move_to_fake(model, self.fake_mode, device)
self.input_fn = input_fn
self.mesh = mesh
self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta

# NB: rest of the construction happens in __enter__

@@ -196,6 +237,10 @@ def __enter__(self):
assert self.active is False

self.build_model_graph()
self.old_inductor_comprehensive_padding = (
torch._inductor.config.comprehensive_padding
)
torch._inductor.config.comprehensive_padding = False

rescale_grad_comm_cost_for_mp = 1.0
if self.mp_policy is not None:
@@ -224,6 +269,9 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
torch._inductor.config.comprehensive_padding = (
self.old_inductor_comprehensive_padding
)
self.active = None
return self.stack.__exit__(exc_type, exc_val, exc_tb)

@@ -252,7 +300,12 @@ def build_model_graph(self):
with set_dtype_cast(True):
ep = torch.export.export(self.model, inputs)
self.joint_with_descriptors = aot_export_joint_with_descriptors(
self.stack, ep.module(), inputs, decompositions=decomp_table
self.stack,
ep.module(),
inputs,
decompositions=decomp_table,
fw_compiler=self.compiler_fn,
bw_compiler=self.compiler_fn,
)
gm = self.joint_with_descriptors.graph_module

@@ -392,8 +445,7 @@ def apply_placement(self, sharding_placement=None):
# parallel_gm, self.params_len, self.buffer_len, self.metadata
# )
self.parallel_gm = parallel_gm
self.joint_with_descriptors.graph_module = parallel_gm

update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm)
# NB: so this function takes in the parameters at the beginning

# let's remove those otherwise we can't clean the backward graph properly
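
Note on the new util: `update_joint_with_descriptors` is only exercised through `apply_placement` above. Below is a minimal standalone sketch of the intended flow; the toy module, inputs, and `my_transform` pass are assumptions for illustration, not code from this PR.

```python
# Sketch: capture a joint graph, transform its GraphModule, then resync the
# captured descriptors/metadata with the new graph's placeholders.
from contextlib import ExitStack

import torch
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

from autoparallel.api import update_joint_with_descriptors


def my_transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # stand-in for a pass that may change input shapes (e.g. sharding to local shards)
    return gm


model = torch.nn.Linear(8, 8)
inputs = (torch.randn(4, 8),)

with ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, model, inputs)
    new_gm = my_transform(joint.graph_module)
    # point joint at new_gm and rebuild flat_args / traced_tangents from its placeholders
    update_joint_with_descriptors(joint, new_gm)
```

The key point is that flat_args and traced_tangents are rebuilt from the new graph's placeholder `val` metadata, so downstream compilation sees the post-transform (possibly resharded) shapes.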
7 changes: 3 additions & 4 deletions autoparallel/apply_sharding.py
@@ -265,21 +265,20 @@ def _cleanup_graph(gm):

def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
args = shard_nodes_given_placements(gm, sharding_placement)
local_args = [arg.to_local() for arg in args]

decomp_table = select_decomp_table()
# run with DTensor to apply the collectives given the graph
interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)

args = [x.to_local() for x in args]

# TODO: make_fx here is suspicious in case of dynamic shapes
with fx_traceback.preserve_node_meta():
parallel_gm0 = make_fx(interp.run)(*args)
parallel_gm0 = make_fx(interp.run)(*local_args)

_cleanup_graph(parallel_gm0)
interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
with fx_traceback.preserve_node_meta():
parallel_gm = make_fx(interp2.run)(*args)
parallel_gm = make_fx(interp2.run)(*local_args)
_cleanup_graph(parallel_gm)

# Copy descriptors over to new graph
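
For context on hoisting the `to_local()` conversion: a DTensor's local shard generally has a different shape than the global tensor, which is also why the joint descriptors in api.py need refreshing after parallelization. A minimal sketch, assuming a single-process gloo group and a one-element CPU mesh purely for illustration:

```python
# Sketch of the DTensor -> local tensor conversion that apply_sharding_to_model
# now performs once, up front, before both make_fx traces.
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
dt = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
local = dt.to_local()  # the per-rank shard; with more ranks, dim 0 would shrink
print(dt.shape, local.shape)

dist.destroy_process_group()
```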
6 changes: 2 additions & 4 deletions examples/example_autoparallel.py
@@ -93,6 +93,7 @@ def forward(self, x):


def input_fn():
print(f"global input shape: {(bs, seq_len, dim1)}")
return torch.rand(bs, seq_len, dim1, device="cuda")


@@ -104,10 +105,9 @@ def input_fn():
# mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
# mp_policy = None

with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)

autop.add_parameter_memory_constraint(low=None, high=None)

x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)
@@ -123,8 +123,6 @@ def input_fn():
parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()

parallel_mod.compile(fullgraph=True)

# now let's run it
x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
out = parallel_mod(*x)
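
With `compile=True`, the forward/backward graphs are handed to Inductor's `compile_fx_inner` during the aot export (see the `compiler_fn` selection in api.py above), so the explicit `parallel_mod.compile(fullgraph=True)` call is dropped from the example. A condensed before/after sketch of the usage change, with elisions and comments mine:

```python
# Before: uncompiled graphs from AutoParallel, compiled afterwards
with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
    ...
parallel_mod.compile(fullgraph=True)

# After: forward/backward graphs go through compile_fx_inner inside AutoParallel
with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
    ...
```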
2 changes: 1 addition & 1 deletion examples/example_llama3.py
@@ -605,7 +605,7 @@ def input_fn():
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# parallelize the model
with AutoParallel(model, input_fn, mesh, mp_policy) as autop:
with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1)