308 changes: 308 additions & 0 deletions autoparallel/activation_checkpointing.py
@@ -0,0 +1,308 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

import torch
from torch._functorch.partitioners import _has_tag_is_backward, _size_of
from torch.utils._ordered_set import OrderedSet
from torch.utils.checkpoint import CheckpointPolicy

logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Reimplement torch._functorch.partitioners.must_recompute
# to only check for the MUST_RECOMPUTE tag, and not PREFER_RECOMPUTE.
# For now the partitioner does not distinguish between the two,
# and I think this is a bug.
def must_recompute(node: torch.fx.Node) -> bool:
return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE


def is_graph_input(node: torch.fx.Node) -> bool:
return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
        # TODO: this needs to be improved, it's firing in the autoparallel "2D" case
        # where the input to the AG is a wait, maybe just 2D FSDP
# ag_node = node.args[0]
# assert is_graph_input(ag_node.args[0]) or (
# ag_node.args[0].op == "call_function"
# and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
# and is_graph_input(ag_node.args[0].args[0])
# ), (
# "Assume all_gather_into_tensor input is either graph input "
# + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
# )
Comment on lines +47 to +57

Contributor: We probably will want to check if we want to fix this, or just always recompute all-gathers in general?

Contributor Author: Well, would it ever be profitable to recompute an allgather for e.g. context parallelism? And would it ever be better to not recompute an allgather for FSDP?

If we want to make this all about FSDP for now, I think we could add some tagging to the nodes when we do the my_redistribute_tensor for param nodes and snapshot the fact that it's for FSDP.

return True
return False
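
A minimal sketch of the tagging idea from the thread above, assuming a hypothetical hook point in the parameter-redistribution code; the meta key "from_fsdp_param" and both helper names are illustrative and not part of this PR:

def tag_fsdp_redistribution_nodes(nodes: list[torch.fx.Node]) -> None:
    # Hypothetical hook: call this on the fx nodes emitted while redistributing
    # a parameter (dtype cast, all_gather_into_tensor, wait_tensor, ...).
    for n in nodes:
        n.meta["from_fsdp_param"] = True


def is_wait_tensor_from_fsdp_via_tag(node: torch.fx.Node) -> bool:
    # With the tag in place, detection becomes a simple meta lookup instead of
    # pattern matching on the surrounding graph structure.
    return is_wait_tensor(node) and node.meta.get("from_fsdp_param", False)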


# mypy: ignore-errors


def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
"""
Force recompute all_gather nodes from simple fsdp in the graph.
This pass should be added in torch._inductor.config.joint_custom_post_pass
"""

def force_recompute_node(node):
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
if "ac_graph_id" not in node.meta:
# ac_graph_id is used in the partitioner to decide
# if two nodes which have AC applied come from a different
# AC regions. This is needed because nodes in the boundary
# of two AC regions are marked as MUST_SAVE. In our case
# we just add a large value of ac_graph_id so that
# all nodes we tag for recomputation do indeed get recomputed
# and are not influenced by other nodes in the graph with
# nearby ac_graph_id values
node.meta["ac_graph_id"] = 100000

# Make all-gather nodes (and related nodes) recomputable, to circumvent
# https://github.com/pytorch/pytorch/issues/136433
for node in graph.nodes:
if is_wait_tensor_from_fsdp(node):
ag_node = node.args[0]
force_recompute_node(ag_node) # all_gather
force_recompute_node(node) # wait_tensor
Contributor: Hmm, if I'm following, right now the logic here kind of needs to mirror the logic in is_wait_tensor_from_fsdp - the first function attempts to look for allgathers that are specifically "FSDP" allgathers (the allgather is done directly on a param, and/or on the result of a downcast if using AMP), and if that condition is true we try to mark the nodes as must recompute.

It might be cleaner just to "force recompute every node in the chain between the current wait tensor and its param" here, that way we don't have to worry about this code and the code in is_wait_tensor_from_fsdp going out of sync. WDYT?

Contributor: Yes, I think that would be a good approach to improve this part. Ideally, though, we would need some heuristic like "all the downstream users of the parameter up until the wait only have a single input node", otherwise we might end up mixing in two-operand operations which do not come from FSDP-like ops.

# Force-recompute slice that comes after wait
for user in node.users:
if (
user.op == "call_function"
and user.target == torch.ops.aten.slice.Tensor
):
force_recompute_node(user)
            # Force-recompute a potential dtype cast feeding the all_gather
if (
ag_node.all_input_nodes[0].op == "call_function"
and ag_node.args[0].target
== torch.ops.prims.convert_element_type.default
):
force_recompute_node(ag_node.all_input_nodes[0])
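
A rough sketch of the "recompute the whole chain between the wait tensor and its param" alternative discussed in the thread above, using the single-input heuristic suggested there; the helper name is made up, and this is not what the merged pass does:

def force_recompute_chain_from_wait(wait_node: torch.fx.Node) -> None:
    # Walk from the wait tensor back towards its (placeholder) parameter and
    # mark every node along the way as recomputable. Bail out if a node has
    # more than one input, since that is unlikely to be a pure
    # param -> dtype cast -> all_gather -> wait chain.
    chain: list[torch.fx.Node] = []
    cur = wait_node
    while cur.op != "placeholder":
        chain.append(cur)
        if len(cur.all_input_nodes) != 1:
            return
        cur = cur.all_input_nodes[0]
    for node in chain:
        node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        node.meta.setdefault("ac_graph_id", 100000)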


def mark_nodes_as_must_save_to_stage_recomputation(
joint_graph: torch.fx.Graph,
stage_size_in_GiB: Optional[float] = None,
) -> None:
"""
Marks specific nodes as "must save" to optimize memory usage during recomputation.
    With aggressive recomputation strategies, we often encounter situations where long chains
    of forward nodes must be recomputed before executing backward pass nodes, causing high
    peak memory usage. This function breaks these recomputation chains into smaller stages
    by periodically saving intermediate nodes, keeping peak memory usage bounded.
Args:
joint_graph: The joint graph containing both forward and backward nodes
stage_size_in_GiB: Target memory size per stage in GiB (-1 to disable staging)
Contributor: FYI this is missing the option to disable the AutoAC. Before, -1 would early return; now it just errors?

"""
INT_INF = int(1e9)
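
    # (Sketch, not part of this diff) One way to restore the "-1 disables
    # staging" behavior referenced in the review comment above would be an
    # early return here, assuming negative values are meant to disable staging:
    #
    #     if stage_size_in_GiB is not None and stage_size_in_GiB < 0:
    #         return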

def get_required_fwd_nodes(
joint_graph: torch.fx.Graph,
) -> OrderedSet[torch.fx.Node]:
"""
Return the set of nodes that are required in the forward graph.
        NOTE: this does something similar to classify_nodes() in _functorch/partitioners.py,
        where nodes are classified into three types -- fwd, bwd, and unclaimed;
        both bwd and unclaimed nodes have partitioner_tag equal to "is_backward".
"""
required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
for node in joint_graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
continue
if node.op == "output":
continue
if _has_tag_is_backward(node):
continue
required_fwd_nodes.add(node)
return required_fwd_nodes

def get_node_distance_to_bwd(
joint_graph: torch.fx.Graph,
        required_fwd_nodes: OrderedSet[torch.fx.Node],
) -> dict[torch.fx.Node, int]:
"""
Compute and return the distance of all nodes to the closest backward node.
If a node is not an ancestor of a backward node, then its distance is INT_INF.
NOTE: this is adapted from
https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097
"""
dist_from_bw = {}
for node in reversed(joint_graph.nodes):
if node.op == "output":
dist_from_bw[node] = INT_INF
            elif node not in required_fwd_nodes:
dist_from_bw[node] = 0
else:
dist_from_bw[node] = INT_INF
for user in node.users:
dist_from_bw[node] = min(dist_from_bw[node], dist_from_bw[user] + 1)
return dist_from_bw

def get_all_recomputable_forward_nodes(
joint_graph: torch.fx.Graph,
) -> OrderedSet[torch.fx.Node]:
"""
Return the set of all forward nodes that are recomputable
"""
required_fwd_nodes = get_required_fwd_nodes(joint_graph)
dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes)
fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
for node in joint_graph.nodes:
if (
node in required_fwd_nodes
and dist_from_bw[node] < INT_INF
and node.op != "placeholder"
):
fwd_recomputable_nodes.add(node)
return fwd_recomputable_nodes

def mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
"""
Given a list of nodes, mark them as must save.
"""
print(f"mark_nodes_as_must_save: {must_save_nodes}")
for node in must_save_nodes:
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph)

# Initialize all nodes as 'prefer recompute' and then adjust only the must-save ones below
for node in fwd_recomputable_nodes:
if node.meta.get("recompute", None) is not None:
# do not mess with allgather nodes that have already been marked recompute!
continue
node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
        # and is the same value we add for the all-gather nodes
node.meta["ac_graph_id"] = 100000

# get the mapping between node name and node
name_to_node_mapping = {}
for node in fwd_recomputable_nodes:
name_to_node_mapping[node.name] = node

# populate node_to_predecessors, accounting for must_recompute nodes. In particular,
# if a node is marked as must recompute, then for its users, their predecessors should
# be updated to be instead the predecessors of the must recompute node.
node_to_predecessors = defaultdict(OrderedSet)
for node in fwd_recomputable_nodes:
node_to_predecessors[node] = OrderedSet(
[pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes]
)
for node in fwd_recomputable_nodes:
if not must_recompute(node):
continue
for user in node.users:
if user in fwd_recomputable_nodes:
node_to_predecessors[user].remove(node)
node_to_predecessors[user].update(node_to_predecessors[node])

# populate node_to_last_usage
# if A is last used by B, then A \in node_to_last_usage[B]
node_to_last_usage = defaultdict(OrderedSet)
last_used_by = {}
for node in fwd_recomputable_nodes:
last_used_by[node] = node
for pred in node_to_predecessors[node]:
last_used_by[pred] = node
for producer, consumer in last_used_by.items():
node_to_last_usage[consumer].add(producer)

    # Loop through nodes in order of the forward graph and keep track of the following:
    # for each node, right before its execution, which nodes' outputs are alive in memory.
@dataclass
class NodeCutScore:
tot_mem: float
alive_node_names: OrderedSet[str]

alive_nodes = OrderedSet()
node2score = {}
for node in fwd_recomputable_nodes:
if not must_recompute(node):
alive_nodes.add(node)
for a in node_to_last_usage[node]:
alive_nodes.remove(a)
tot_mem = sum(_size_of(node) for node in alive_nodes)
node2score[node] = NodeCutScore(
tot_mem, OrderedSet([n.name for n in alive_nodes])
)

# divide the graph into stages with roughly equal memory usage.
stages = defaultdict(OrderedSet)
cum_mem_so_far = 0
curr_stage_idx = 0

if stage_size_in_GiB is None:
total_used_memory = sum(
_size_of(node)
for node in fwd_recomputable_nodes
if not must_recompute(node)
)
total_used_memory_in_GiB = total_used_memory / 2**30
stage_size_in_GiB = total_used_memory_in_GiB**0.5
print(f"Computed stage_size {stage_size_in_GiB=}")
target_mem = stage_size_in_GiB * 2**30
for node in fwd_recomputable_nodes:
stages[curr_stage_idx].add(node)
if not must_recompute(node):
cum_mem_so_far += _size_of(node)
if cum_mem_so_far >= target_mem:
curr_stage_idx += 1
cum_mem_so_far = 0

# loop through each stage and pick the best node to cut on, and save
# the nodes that will be marked as must save.
nodes_to_save = OrderedSet()
for stage in stages.values():
best_node = min(stage, key=lambda x: node2score[x].tot_mem)
nodes_to_save.update(node2score[best_node].alive_node_names)
mark_nodes_as_must_save([name_to_node_mapping[n] for n in nodes_to_save])

save_list = {
torch.ops.aten.mm.default,
# torch.ops.aten._scaled_dot_product_efficient_attention.default,
# torch.ops.aten._scaled_dot_product_flash_attention.default,
}
must_save_nodes = []
counter = 0
for node in fwd_recomputable_nodes:
if node.target in save_list:
if node.target == torch.ops.aten.mm.default:
if counter % 2 == 0:
counter += 1
else:
counter += 1
continue
must_save_nodes.append(node)
mark_nodes_as_must_save(must_save_nodes)
Comment on lines +285 to +301

Contributor:
FYI I think this needs to be moved out of this function, as it was just a specific tagging for Llama and we need to add user control of it
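
One possible shape for that user control, sketched with made-up names (op_save_list, save_every_nth); the hard-coded every-other-matmul policy above would become a parameter the caller chooses:

def mark_ops_as_must_save(
    fwd_recomputable_nodes: OrderedSet[torch.fx.Node],
    op_save_list: tuple = (torch.ops.aten.mm.default,),
    save_every_nth: int = 2,
) -> None:
    # Save every nth occurrence of the listed ops (every other matmul by
    # default) instead of baking the Llama-specific policy into the staging pass.
    counter = 0
    for node in fwd_recomputable_nodes:
        if node.target in op_save_list:
            if counter % save_every_nth == 0:
                node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
            counter += 1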



def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
force_recompute_fsdp_all_gather(graph)
mark_nodes_as_must_save_to_stage_recomputation(
graph, stage_size_in_GiB=ac_stage_size_in_GiB
)
10 changes: 9 additions & 1 deletion autoparallel/api.py
@@ -29,6 +29,7 @@
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState

from .activation_checkpointing import ac_joint_pass
from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
from .init_weights import hook_params_setters
@@ -203,6 +204,9 @@ def __init__(
mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy] = None,
compile: bool = False,
enable_ac: bool = True,
Contributor: I think I wouldn't pass this as a boolean flag in the constructor. We might need to add other AC tags as well (like save some matmuls, which is what I did for Llama3 8B). So I'd maybe keep this outside of the constructor somehow.

Contributor Author: So, I thought a bit about this, and I don't have a strong opinion about where the knob goes, but I think we need the knob, at least for debugging purposes. I'd also prefer it not to be something hacky like an ENV or global state, because those can be harder to control when submitting mast CLI jobs. It's pretty nice to be able to wire up all the things we care about to the torchtitan CLI.

Would you feel better about making a method configure_ac(enable: bool = True, size: Optional[float] = None, ...) that users can call on the autop object?

Contributor (@bdhirsh, Aug 18, 2025): I'm a bit confused about the API design. How are we reconciling these 3 things:

(1) The user might want to use compiler-driven AC. I think this is effectively what this PR tries to support through the enable_ac flag and the custom graph pass, is that right? What is the relationship between the above graph pass and the autoAC code that lives directly in the partitioner - do we plan to reconcile the two in the future?

(2) The user might be manually applying AC in their single-GPU model code (my understanding is that we actually do support this, independently of this PR).

(3) We generally want to ensure that AGs for FSDP are always recomputed. This PR also does this. One thing I'm not clear about, though: don't we generally always want to recompute FSDP allgathers, and if so, why make this part of the flag at all?

Contributor Author: I think we need to document it a bit better.

The way it works (iirc) is:

  • FSDP allgathers are always recomputed; the flag doesn't change this
  • "auto": use the 'stage-size heuristic' in this PR to determine the stage size for the autoAC
  • Int: use the autoAC in this PR, but don't compute the stage size automatically; use this size
  • None: pass through user annotations (useful if they passed in AC hints)

Note: this comment applies to what's on main, not what's in this PR. The delta is None vs "auto", which I merged together in this PR but @fmassa later unmerged.

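A sketch of what such a knob could look like as a method on the autop object, following the semantics described above; configure_ac is only a proposal from this discussion, not something the PR adds:

def configure_ac(
    self,
    enable: bool = True,
    stage_size_in_GiB: Optional[float] = None,
) -> None:
    # enable=False: pass user AC annotations through untouched.
    # stage_size_in_GiB=None: let the pass compute a stage size automatically.
    # stage_size_in_GiB=<number>: use that size for the staging heuristic.
    self.enable_ac = enable
    self.ac_stage_size_in_GiB = stage_size_in_GiB
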
# None means 'auto'
ac_stage_size_in_GiB: Optional[float] = None,
):
self.stack = ExitStack()
self.fake_mode = (
@@ -228,9 +232,10 @@ def __init__(
self.input_fn = input_fn
self.mesh = mesh
self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB

# NB: rest of the construction happens in __enter__

self.active = False

def __enter__(self):
@@ -439,6 +444,9 @@ def apply_placement(self, sharding_placement=None):
},
payload_fn=lambda: str(parallel_gm.graph),
)

if self.enable_ac:
ac_joint_pass(parallel_gm.graph, self.ac_stage_size_in_GiB)
# now rename input/param/tangent/output/grad_param/grad_input nodes following
# our convention
# apply_node_renaming(