# Add Activation Checkpointing Pass #83
New file, `activation_checkpointing.py` (`@@ -0,0 +1,308 @@`):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

import torch
from torch._functorch.partitioners import _has_tag_is_backward, _size_of
from torch.utils._ordered_set import OrderedSet
from torch.utils.checkpoint import CheckpointPolicy

logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Reimplementation of torch._functorch.partitioners.must_recompute that only
# checks for the MUST_RECOMPUTE tag, not PREFER_RECOMPUTE. For now the
# partitioner makes no distinction between the two, which I think is a bug.
def must_recompute(node: torch.fx.Node) -> bool:
    return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE


def is_graph_input(node: torch.fx.Node) -> bool:
    return node.op == "placeholder"


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
        # TODO: this needs to be improved; it's firing in the autoparallel "2D"
        # case where the input to the all-gather is itself a wait, maybe just 2D FSDP
        # ag_node = node.args[0]
        # assert is_graph_input(ag_node.args[0]) or (
        #     ag_node.args[0].op == "call_function"
        #     and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
        #     and is_graph_input(ag_node.args[0].args[0])
        # ), (
        #     "Assume all_gather_into_tensor input is either graph input "
        #     + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
        # )
        return True
    return False


# mypy: ignore-errors


def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Force recomputation of all_gather nodes from simple FSDP in the graph.
    This pass should be added in torch._inductor.config.joint_custom_post_pass
    """

    def force_recompute_node(node):
        node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        if "ac_graph_id" not in node.meta:
            # ac_graph_id is used in the partitioner to decide whether two
            # nodes which have AC applied come from different AC regions.
            # This is needed because nodes on the boundary of two AC regions
            # are marked as MUST_SAVE. In our case we just add a large value
            # of ac_graph_id so that all nodes we tag for recomputation do
            # indeed get recomputed and are not influenced by other nodes in
            # the graph with nearby ac_graph_id values.
            node.meta["ac_graph_id"] = 100000

    # Make all-gather nodes (and related nodes) recomputable, to circumvent
    # https://github.com/pytorch/pytorch/issues/136433
    for node in graph.nodes:
        if is_wait_tensor_from_fsdp(node):
            ag_node = node.args[0]
            force_recompute_node(ag_node)  # all_gather
            force_recompute_node(node)  # wait_tensor
```
**Contributor:** hmm, if I'm following, right now the logic here kind of needs to mirror the logic that produces these all-gather chains. It might be cleaner just to "force recompute every node in the chain between the current wait tensor and its param" here; that way we don't have to worry about keeping this code and that code in sync.

**Contributor:** Yes, I think that would be a good approach to improve this part. Ideally, though, we would need some heuristic, e.g. that all the downstream users of the parameter up until the wait have only a single input node; otherwise we might end up mixing in two-operand operations which are not FSDP-like ops.
```python
            # Force-recompute the slice that comes after the wait
            for user in node.users:
                if (
                    user.op == "call_function"
                    and user.target == torch.ops.aten.slice.Tensor
                ):
                    force_recompute_node(user)
            # Force-recompute potential dtype casts feeding the all_gather
            if (
                ag_node.all_input_nodes[0].op == "call_function"
                and ag_node.args[0].target
                == torch.ops.prims.convert_element_type.default
            ):
                force_recompute_node(ag_node.all_input_nodes[0])


def mark_nodes_as_must_save_to_stage_recomputation(
    joint_graph: torch.fx.Graph,
    stage_size_in_GiB: Optional[float] = None,
) -> None:
    """
    Marks specific nodes as "must save" to optimize memory usage during
    recomputation. With aggressive recomputation strategies, we often
    encounter situations where long chains of forward nodes must be
    recomputed before executing backward pass nodes, causing high peak
    memory usage. This function breaks these recomputation chains into
    smaller stages by periodically saving intermediate nodes, keeping peak
    memory usage bounded.

    Args:
        joint_graph: The joint graph containing both forward and backward nodes
        stage_size_in_GiB: Target memory size per stage in GiB (-1 to disable staging)
    """
```
**Contributor:** FYI, this is missing the option to disable the AutoAC. Before, `-1` would early-return; now it just errors?
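If the old `-1` escape hatch is meant to survive, a guard at the top of the function could restore it (a sketch; the exact sentinel semantics are an assumption):

```python
    # Hypothetical early return: keep -1 (or any negative value) as the
    # "staging disabled" sentinel instead of letting it error below.
    if stage_size_in_GiB is not None and stage_size_in_GiB < 0:
        return
```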
```python
    INT_INF = int(1e9)

    def get_required_fwd_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of nodes that are required in the forward graph.
        NOTE: this does something similar to classify_nodes() in
        _functorch/partitioners.py, where nodes are classified into three
        types -- fwd, bwd, and unclaimed; both bwd and unclaimed nodes have
        partitioner_tag equal to "is_backward".
        """
        required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                continue
            if node.op == "output":
                continue
            if _has_tag_is_backward(node):
                continue
            required_fwd_nodes.add(node)
        return required_fwd_nodes

    def get_node_distance_to_bwd(
        joint_graph: torch.fx.Graph,
        required_fwd_nodes: OrderedSet[torch.fx.Node],
    ) -> dict[torch.fx.Node, int]:
        """
        Compute and return the distance of all nodes to the closest backward node.
        If a node is not an ancestor of a backward node, then its distance is INT_INF.
        NOTE: this is adapted from
        https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097
        """
        dist_from_bw = {}
        for node in reversed(joint_graph.nodes):
            if node.op == "output":
                dist_from_bw[node] = INT_INF
            elif node not in required_fwd_nodes:
                dist_from_bw[node] = 0
            else:
                dist_from_bw[node] = INT_INF
                for user in node.users:
                    dist_from_bw[node] = min(dist_from_bw[node], dist_from_bw[user] + 1)
        return dist_from_bw

    def get_all_recomputable_forward_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of all forward nodes that are recomputable.
        """
        required_fwd_nodes = get_required_fwd_nodes(joint_graph)
        dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes)
        fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if (
                node in required_fwd_nodes
                and dist_from_bw[node] < INT_INF
                and node.op != "placeholder"
            ):
                fwd_recomputable_nodes.add(node)
        return fwd_recomputable_nodes

    def mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
        """
        Given a list of nodes, mark them as must-save.
        """
        logger.info("mark_nodes_as_must_save: %s", must_save_nodes)
        for node in must_save_nodes:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

    fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph)

    # Initialize all nodes as 'prefer recompute' and then adjust only the
    # must-save ones below.
    for node in fwd_recomputable_nodes:
        if node.meta.get("recompute", None) is not None:
            # Do not mess with all-gather nodes that have already been
            # marked for recomputation!
            continue
        node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        # Add an arbitrarily large ac_graph_id. I'm assuming 100000 here,
        # which should be fine and is the same value we add for the
        # all-gather nodes.
        node.meta["ac_graph_id"] = 100000

    # Get the mapping between node name and node.
    name_to_node_mapping = {}
    for node in fwd_recomputable_nodes:
        name_to_node_mapping[node.name] = node

    # Populate node_to_predecessors, accounting for must_recompute nodes.
    # In particular, if a node is marked as must-recompute, then for its
    # users, their predecessors should be updated to instead be the
    # predecessors of the must-recompute node.
    node_to_predecessors = defaultdict(OrderedSet)
    for node in fwd_recomputable_nodes:
        node_to_predecessors[node] = OrderedSet(
            [pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes]
        )
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            continue
        for user in node.users:
            if user in fwd_recomputable_nodes:
                node_to_predecessors[user].remove(node)
                node_to_predecessors[user].update(node_to_predecessors[node])

    # Populate node_to_last_usage:
    # if A is last used by B, then A is in node_to_last_usage[B].
    node_to_last_usage = defaultdict(OrderedSet)
    last_used_by = {}
    for node in fwd_recomputable_nodes:
        last_used_by[node] = node
        for pred in node_to_predecessors[node]:
            last_used_by[pred] = node
    for producer, consumer in last_used_by.items():
        node_to_last_usage[consumer].add(producer)

    # Loop through nodes in order of the forward graph and keep track of the
    # following: for each node, right before its execution, the outputs of
    # which nodes are in memory.
    @dataclass
    class NodeCutScore:
        tot_mem: float
        alive_node_names: OrderedSet[str]

    alive_nodes = OrderedSet()
    node2score = {}
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            alive_nodes.add(node)
            for a in node_to_last_usage[node]:
                alive_nodes.remove(a)
        tot_mem = sum(_size_of(n) for n in alive_nodes)
        node2score[node] = NodeCutScore(
            tot_mem, OrderedSet([n.name for n in alive_nodes])
        )

    # Divide the graph into stages with roughly equal memory usage.
    stages = defaultdict(OrderedSet)
    cum_mem_so_far = 0
    curr_stage_idx = 0

    if stage_size_in_GiB is None:
        total_used_memory = sum(
            _size_of(node)
            for node in fwd_recomputable_nodes
            if not must_recompute(node)
        )
        total_used_memory_in_GiB = total_used_memory / 2**30
        stage_size_in_GiB = total_used_memory_in_GiB**0.5
        logger.info("Computed stage size: %s GiB", stage_size_in_GiB)
    target_mem = stage_size_in_GiB * 2**30
    for node in fwd_recomputable_nodes:
        stages[curr_stage_idx].add(node)
        if not must_recompute(node):
            cum_mem_so_far += _size_of(node)
            if cum_mem_so_far >= target_mem:
                curr_stage_idx += 1
                cum_mem_so_far = 0
```
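For intuition on the `stage_size_in_GiB is None` branch above: picking the stage size as the square root of the total saved-activation footprint is the classic O(√n) checkpointing trade-off, roughly balancing per-stage recompute memory against the number of saved cut points. An illustrative calculation (numbers assumed, not from the PR):

```python
# With, say, 64 GiB of activations that would otherwise be saved:
total_used_memory_in_GiB = 64.0
stage_size_in_GiB = total_used_memory_in_GiB**0.5          # 8.0 GiB per stage
num_stages = total_used_memory_in_GiB / stage_size_in_GiB  # ~8 stages
```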
```python
    # Loop through each stage, pick the best node to cut on, and save the
    # nodes that will be marked as must-save.
    nodes_to_save = OrderedSet()
    for stage in stages.values():
        best_node = min(stage, key=lambda x: node2score[x].tot_mem)
        nodes_to_save.update(node2score[best_node].alive_node_names)
    mark_nodes_as_must_save([name_to_node_mapping[n] for n in nodes_to_save])

    save_list = {
        torch.ops.aten.mm.default,
        # torch.ops.aten._scaled_dot_product_efficient_attention.default,
        # torch.ops.aten._scaled_dot_product_flash_attention.default,
    }
    must_save_nodes = []
    counter = 0
    for node in fwd_recomputable_nodes:
        if node.target in save_list:
            if node.target == torch.ops.aten.mm.default:
                # Save every other matmul.
                if counter % 2 == 0:
                    counter += 1
                else:
                    counter += 1
                    continue
            must_save_nodes.append(node)
    mark_nodes_as_must_save(must_save_nodes)
```
**Comment on lines +285 to +301**

**Contributor:** FYI, I think this needs to be moved out of this function, as it was just a specific tagging for Llama, and we need to add user control over it.
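One shape that user control could take, as a standalone helper (a sketch; `save_list` as a parameter and `mm_save_stride` are hypothetical names, not part of this PR):

```python
def mark_save_list_nodes(
    fwd_recomputable_nodes,
    save_list=frozenset({torch.ops.aten.mm.default}),
    mm_save_stride: int = 2,
) -> None:
    # Same Llama-specific tagging as above, but with the op list and the
    # "save every Nth matmul" stride exposed to the caller.
    counter = 0
    for node in fwd_recomputable_nodes:
        if node.target not in save_list:
            continue
        if node.target == torch.ops.aten.mm.default:
            counter += 1
            if (counter - 1) % mm_save_stride != 0:
                continue
        node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
```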
```python
def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
    force_recompute_fsdp_all_gather(graph)
    mark_nodes_as_must_save_to_stage_recomputation(
        graph, stage_size_in_GiB=ac_stage_size_in_GiB
    )
```
Second changed file (filename not shown in this view):

`@@ -29,6 +29,7 @@`

```python
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState

from .activation_checkpointing import ac_joint_pass
from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
from .init_weights import hook_params_setters
```
`@@ -203,6 +204,9 @@ def __init__(`

```python
        mesh: DeviceMesh,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        compile: bool = False,
        enable_ac: bool = True,
```
**Contributor:** I think I wouldn't pass this as a boolean flag in the constructor. We might need to add other AC tags as well (like saving some matmuls, which is what I did for Llama3 8B), so I'd maybe keep this outside of the constructor somehow.

**Contributor (author):** So, I thought a bit about this, and I don't have a strong opinion about where the knob goes, but I think we need the knob, at least for debugging purposes. I'd also prefer it not to be something hacky like an env var or global state, because those are harder to control when submitting mast CLI jobs; it's pretty nice to be able to wire up all the things we care about to the torchtitan CLI. Would you feel better about making it a method?

**Contributor:** I'm a bit confused about the API design. How are we reconciling these three things: (1) the user might want to use compiler-driven AC, which I think is effectively what this PR tries to support through the `enable_ac` flag; (2) the user might be manually applying AC in their single-GPU model code (my understanding is that we actually do support this, independently from this PR); (3) we generally want to ensure that all-gathers for FSDP are always recomputed, which this PR also does. One thing I'm not clear about, though: don't we generally always want to recompute FSDP all-gathers, and if so, why make this part of the flag at all?

**Contributor (author):** I think we need to document it a bit better. The way it works (IIRC) is…

Note: this comment should apply to what's on main, not what's in this PR; the delta being None vs. "auto", which I merged together in this PR but @fmassa later unmerged.
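A sketch of the method-style knob floated above (the `configure_ac` name and signature are hypothetical), which would keep AC options out of the constructor while still being wireable from a CLI:

```python
# Hypothetical method on the class instead of constructor flags, leaving
# room for future AC tags (e.g. "save some matmuls") without growing the
# constructor signature.
def configure_ac(
    self,
    enabled: bool = True,
    stage_size_in_GiB: Optional[float] = None,  # None means 'auto'
) -> None:
    self.enable_ac = enabled
    self.ac_stage_size_in_GiB = stage_size_in_GiB
```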
```python
        # None means 'auto'
        ac_stage_size_in_GiB: Optional[float] = None,
    ):
        self.stack = ExitStack()
        self.fake_mode = (
```
`@@ -228,9 +232,10 @@ def __init__(`

```python
        self.input_fn = input_fn
        self.mesh = mesh
        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
        self.enable_ac = enable_ac
        self.ac_stage_size_in_GiB = ac_stage_size_in_GiB

        # NB: rest of the construction happens in __enter__

        self.active = False

    def __enter__(self):
```
`@@ -439,6 +444,9 @@ def apply_placement(self, sharding_placement=None):`

```python
            },
            payload_fn=lambda: str(parallel_gm.graph),
        )

        if self.enable_ac:
            ac_joint_pass(parallel_gm.graph, self.ac_stage_size_in_GiB)
        # now rename input/param/tangent/output/grad_param/grad_input nodes
        # following our convention
        # apply_node_renaming(
```
**Contributor:** We probably will want to check whether we want to fix this, or just always recompute all-gathers in general?

**Contributor:** Well, would it ever be profitable to recompute an all-gather for e.g. context parallelism? And would it ever be better to not recompute an all-gather for FSDP? If we want to make this all about FSDP for now, I think we could add some tagging to the nodes when we do the my_redistribute_tensor for param nodes and snapshot the fact that it's for FSDP.
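A sketch of that tagging idea: stamp the all-gather when it is emitted for an FSDP parameter, then match on the tag instead of the graph pattern (the `from_fsdp` meta key and the emission site are assumptions, not code from this PR):

```python
# At the point where my_redistribute_tensor emits an all-gather for a
# parameter node, record why it exists:
ag_node.meta["from_fsdp"] = True

# The detection helper could then check the tag rather than pattern-matching:
def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and node.args[0].meta.get("from_fsdp", False)
```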