|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +import logging |
| 6 | +from collections import defaultdict |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import Optional |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch._functorch.partitioners import _has_tag_is_backward, _size_of |
| 12 | +from torch.utils._ordered_set import OrderedSet |
| 13 | +from torch.utils.checkpoint import CheckpointPolicy |
| 14 | + |
# Module-level logger for this file; its level is pinned to INFO at import time.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
| 17 | + |
| 18 | + |
| 19 | +# reimplement torch._functorch.partitioners.must_recompute |
| 20 | +# to only check for MUST_RECOMPUTE tag, and not PREFER_RECOMPUTE |
| 21 | +# For now there isn't any distinction in the partitioner between both |
| 22 | +# and I think this is a bug |
| 23 | +def must_recompute(node: torch.fx.Node) -> bool: |
| 24 | + return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE |
| 25 | + |
| 26 | + |
def is_graph_input(node: torch.fx.Node) -> bool:
    """Return True when ``node`` is a ``placeholder`` node of the FX graph."""
    node_kind = node.op
    return node_kind == "placeholder"
| 29 | + |
| 30 | + |
| 31 | +def is_wait_tensor(node: torch.fx.Node) -> bool: |
| 32 | + return ( |
| 33 | + node.op == "call_function" |
| 34 | + and node.target == torch.ops._c10d_functional.wait_tensor.default |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: |
| 39 | + return ( |
| 40 | + node.op == "call_function" |
| 41 | + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default |
| 42 | + ) |
| 43 | + |
| 44 | + |
def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    """Return True when ``node`` is a wait_tensor whose first argument is an
    all_gather_into_tensor node.

    This is the pattern used here to recognise FSDP all-gathers.
    """
    # TODO: this needs to be improved. A stricter check used to assert that the
    # all_gather input is either a graph input or a
    # prims.convert_element_type of a graph input, but it was disabled because
    # it fires in the autoparallel "2D" case where the input to the all_gather
    # is itself a wait (maybe just 2D FSDP).
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])
| 60 | + |
| 61 | + |
| 62 | +# mypy: ignore-errors |
| 63 | + |
| 64 | + |
def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    """
    Force recompute all_gather nodes from simple fsdp in the graph.
    This pass should be added in torch._inductor.config.joint_custom_post_pass

    For every wait_tensor fed by an all_gather_into_tensor, the following nodes
    are tagged ``CheckpointPolicy.MUST_RECOMPUTE`` in place:
      * the all_gather node and the wait_tensor node,
      * every ``aten.slice.Tensor`` user of the wait_tensor,
      * the first input node of the all_gather when it is a
        ``prims.convert_element_type`` (dtype cast).

    Args:
        graph: FX graph whose node metadata is mutated in place.
    """

    def force_recompute_node(node: torch.fx.Node) -> None:
        node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
        # ac_graph_id is used in the partitioner to decide if two nodes which
        # have AC applied come from different AC regions. This is needed
        # because nodes in the boundary of two AC regions are marked as
        # MUST_SAVE. In our case we just add a large value of ac_graph_id so
        # that all nodes we tag for recomputation do indeed get recomputed and
        # are not influenced by other nodes in the graph with nearby
        # ac_graph_id values. An existing id is left untouched.
        node.meta.setdefault("ac_graph_id", 100000)

    # Make all-gather nodes (and related nodes) recomputable, to circumvent
    # https://github.com/pytorch/pytorch/issues/136433
    for node in graph.nodes:
        if not is_wait_tensor_from_fsdp(node):
            continue

        ag_node = node.args[0]
        force_recompute_node(ag_node)  # all_gather
        force_recompute_node(node)  # wait_tensor

        # Force-recompute slice that comes after wait
        for user in node.users:
            if (
                user.op == "call_function"
                and user.target == torch.ops.aten.slice.Tensor
            ):
                force_recompute_node(user)

        # Force-recompute potential dtype casts feeding the all_gather.
        # The same node is used for the op check, the target check and the
        # tagging, so the three can never disagree.
        ag_inputs = ag_node.all_input_nodes
        if not ag_inputs:
            continue
        ag_input = ag_inputs[0]
        if (
            ag_input.op == "call_function"
            and ag_input.target == torch.ops.prims.convert_element_type.default
        ):
            force_recompute_node(ag_input)
| 105 | + |
| 106 | + |
def mark_nodes_as_must_save_to_stage_recomputation(
    joint_graph: torch.fx.Graph,
    stage_size_in_GiB: Optional[float] = None,
) -> None:
    """
    Marks specific nodes as "must save" to optimize memory usage during recomputation.

    With aggressive recomputation strategies, we often encounter situations where long
    chains of forward nodes must be recomputed before executing backward pass nodes,
    causing high peak memory usage. This function breaks these recomputation chains
    into smaller stages by periodically saving intermediate nodes.

    The pass mutates ``node.meta`` in place:
      1. Every recomputable forward node that has no ``"recompute"`` tag yet is tagged
         ``PREFER_RECOMPUTE`` and given ``ac_graph_id = 100000``.
      2. The recomputable forward nodes are split, in graph order, into stages whose
         cumulative output size reaches the target stage size. In each stage, the
         node with the smallest set of live outputs is chosen as the cut point and
         those live outputs are tagged ``MUST_SAVE``.
      3. Every other ``aten.mm.default`` node (starting with the first) is tagged
         ``MUST_SAVE`` as well.

    Args:
        joint_graph: The joint graph containing both forward and backward nodes.
        stage_size_in_GiB: Target memory size per stage in GiB. When ``None``, the
            square root of the total size (in GiB) of the recomputable forward nodes
            that are not tagged ``MUST_RECOMPUTE`` is used.
            NOTE(review): a previous version of this docstring mentioned ``-1`` as
            "disable staging"; no such special case exists in the code. With a
            non-positive value every node appears to end up in its own stage
            (assuming ``_size_of`` is non-negative) -- confirm the intended contract.
    """
    INT_INF = int(1e9)

    def get_required_fwd_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of nodes that are required in the forward graph.
        NOTE: this is doing similar things as classify_nodes() in _functorch/partitioners.py
        where nodes are classified into three types -- fwd, bwd, and unclaimed
        both bwd and unclaimed nodes have partitioner_tag equal to "is_backward"
        """
        required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                continue
            if node.op == "output":
                continue
            if _has_tag_is_backward(node):
                continue
            required_fwd_nodes.add(node)
        return required_fwd_nodes

    def get_node_distance_to_bwd(
        joint_graph: torch.fx.Graph,
        required_fwd_nodes: OrderedSet[torch.fx.Node],
    ) -> dict[torch.fx.Node, int]:
        """
        Compute and return the distance of all nodes to the closest backward node.
        If a node is not an ancestor of a backward node, then its distance is INT_INF.
        NOTE: this is adapted from
        https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097
        """
        dist_from_bw: dict[torch.fx.Node, int] = {}
        # Walk in reverse graph order so every user is visited before its inputs.
        for node in reversed(joint_graph.nodes):
            if node.op == "output":
                dist_from_bw[node] = INT_INF
            elif node not in required_fwd_nodes:
                dist_from_bw[node] = 0
            else:
                dist_from_bw[node] = INT_INF
                for user in node.users:
                    dist_from_bw[node] = min(
                        dist_from_bw[node], dist_from_bw[user] + 1
                    )
        return dist_from_bw

    def get_all_recomputable_forward_nodes(
        joint_graph: torch.fx.Graph,
    ) -> OrderedSet[torch.fx.Node]:
        """
        Return the set of all forward nodes that are recomputable, i.e. required
        forward nodes that are not placeholders and that reach a backward node.
        """
        required_fwd_nodes = get_required_fwd_nodes(joint_graph)
        dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes)
        fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
        for node in joint_graph.nodes:
            if (
                node in required_fwd_nodes
                and dist_from_bw[node] < INT_INF
                and node.op != "placeholder"
            ):
                fwd_recomputable_nodes.add(node)
        return fwd_recomputable_nodes

    def mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
        """
        Given a list of nodes, mark them as must save.
        """
        logger.info("mark_nodes_as_must_save: %s", must_save_nodes)
        for node in must_save_nodes:
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

    fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph)

    # Initialize all nodes as 'prefer recompute' and then adjust only the must-save ones below
    for node in fwd_recomputable_nodes:
        if node.meta.get("recompute", None) is not None:
            # do not mess with allgather nodes that have already been marked recompute!
            continue
        node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
        # and is the same we add for the all-gather nodes
        node.meta["ac_graph_id"] = 100000

    # populate node_to_predecessors, accounting for must_recompute nodes. In particular,
    # if a node is marked as must recompute, then for its users, their predecessors should
    # be updated to be instead the predecessors of the must recompute node.
    node_to_predecessors = defaultdict(OrderedSet)
    for node in fwd_recomputable_nodes:
        node_to_predecessors[node] = OrderedSet(
            [pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes]
        )
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            continue
        for user in node.users:
            if user in fwd_recomputable_nodes:
                node_to_predecessors[user].discard(node)
                node_to_predecessors[user].update(node_to_predecessors[node])

    # populate node_to_last_usage
    # if A is last used by B, then A \in node_to_last_usage[B]
    node_to_last_usage = defaultdict(OrderedSet)
    last_used_by: dict[torch.fx.Node, torch.fx.Node] = {}
    for node in fwd_recomputable_nodes:
        last_used_by[node] = node
        for pred in node_to_predecessors[node]:
            last_used_by[pred] = node
    for producer, consumer in last_used_by.items():
        node_to_last_usage[consumer].add(producer)

    # loop through nodes in order of the forward graph and keep track of the following:
    # for each node, right before its execution, the output of what nodes are in memory.
    @dataclass
    class NodeCutScore:
        tot_mem: float
        alive_nodes: OrderedSet[torch.fx.Node]

    alive_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
    node2score: dict[torch.fx.Node, NodeCutScore] = {}
    for node in fwd_recomputable_nodes:
        if not must_recompute(node):
            alive_nodes.add(node)
        for dead in node_to_last_usage[node]:
            # discard(), not remove(): a MUST_RECOMPUTE node is never added to
            # alive_nodes, yet it is recorded as its own last user once it has
            # been spliced out of its users' predecessor sets. remove() would
            # raise KeyError for it.
            alive_nodes.discard(dead)
        tot_mem = sum(_size_of(alive) for alive in alive_nodes)
        # Snapshot the live set; alive_nodes keeps mutating in later iterations.
        node2score[node] = NodeCutScore(tot_mem, OrderedSet(alive_nodes))

    # divide the graph into stages with roughly equal memory usage.
    stages = defaultdict(OrderedSet)
    cum_mem_so_far = 0
    curr_stage_idx = 0

    if stage_size_in_GiB is None:
        total_used_memory = sum(
            _size_of(node)
            for node in fwd_recomputable_nodes
            if not must_recompute(node)
        )
        total_used_memory_in_GiB = total_used_memory / 2**30
        stage_size_in_GiB = total_used_memory_in_GiB**0.5
        logger.info("Computed stage_size stage_size_in_GiB=%r", stage_size_in_GiB)
    target_mem = stage_size_in_GiB * 2**30
    for node in fwd_recomputable_nodes:
        stages[curr_stage_idx].add(node)
        if not must_recompute(node):
            cum_mem_so_far += _size_of(node)
        if cum_mem_so_far >= target_mem:
            curr_stage_idx += 1
            cum_mem_so_far = 0

    # loop through each stage and pick the best node to cut on, and save
    # the nodes that will be marked as must save.
    nodes_to_save: OrderedSet[torch.fx.Node] = OrderedSet()
    for stage in stages.values():
        best_node = min(stage, key=lambda candidate: node2score[candidate].tot_mem)
        nodes_to_save.update(node2score[best_node].alive_nodes)
    mark_nodes_as_must_save(list(nodes_to_save))

    # Additionally save every other matmul output, starting with the first one.
    save_list = {
        torch.ops.aten.mm.default,
        # torch.ops.aten._scaled_dot_product_efficient_attention.default,
        # torch.ops.aten._scaled_dot_product_flash_attention.default,
    }
    must_save_nodes = []
    mm_seen = 0
    for node in fwd_recomputable_nodes:
        if node.target not in save_list:
            continue
        if node.target == torch.ops.aten.mm.default:
            keep_this_mm = mm_seen % 2 == 0
            mm_seen += 1
            if not keep_this_mm:
                continue
        must_save_nodes.append(node)
    mark_nodes_as_must_save(must_save_nodes)
| 302 | + |
| 303 | + |
def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
    """Run both activation-checkpointing passes on ``graph``, in place.

    First the FSDP all_gather pattern is tagged for forced recomputation, then
    MUST_SAVE cut points are chosen with ``ac_stage_size_in_GiB`` as the target
    stage size.
    """
    force_recompute_fsdp_all_gather(graph)

    stage_size = ac_stage_size_in_GiB
    mark_nodes_as_must_save_to_stage_recomputation(
        graph,
        stage_size_in_GiB=stage_size,
    )
0 commit comments