Skip to content

Commit c65facd

Browse files
authored
Add Activation Checkpointing Pass (#83)
Work in progress: copied the latest code over from `whc/hack_aot` and tweaked the way it gets hooked up a bit, haven't tested yet. Likely need to discuss whether we want the AC pass to be popped back off inductor's passes earlier or keep it at __exit__ from AutoParallel.
1 parent 75cef61 commit c65facd

File tree

2 files changed

+317
-1
lines changed

2 files changed

+317
-1
lines changed
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import logging
6+
from collections import defaultdict
7+
from dataclasses import dataclass
8+
from typing import Optional
9+
10+
import torch
11+
from torch._functorch.partitioners import _has_tag_is_backward, _size_of
12+
from torch.utils._ordered_set import OrderedSet
13+
from torch.utils.checkpoint import CheckpointPolicy
14+
15+
logger: logging.Logger = logging.getLogger(__name__)
16+
logger.setLevel(logging.INFO)
17+
18+
19+
# reimplement torch._functorch.partitioners.must_recompute
20+
# to only check for MUST_RECOMPUTE tag, and not PREFER_RECOMPUTE
21+
# For now there isn't any distinction in the partitioner between both
22+
# and I think this is a bug
23+
def must_recompute(node: torch.fx.Node) -> bool:
24+
return node.meta.get("recompute", None) is CheckpointPolicy.MUST_RECOMPUTE
25+
26+
27+
def is_graph_input(node: torch.fx.Node) -> bool:
28+
return node.op == "placeholder"
29+
30+
31+
def is_wait_tensor(node: torch.fx.Node) -> bool:
32+
return (
33+
node.op == "call_function"
34+
and node.target == torch.ops._c10d_functional.wait_tensor.default
35+
)
36+
37+
38+
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
39+
return (
40+
node.op == "call_function"
41+
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
42+
)
43+
44+
45+
def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
46+
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
47+
# TODO: this needs to be improved, its firing in autoparallel "2D" case where input to AG is wait,
48+
# maybe just 2D FSDP
49+
# ag_node = node.args[0]
50+
# assert is_graph_input(ag_node.args[0]) or (
51+
# ag_node.args[0].op == "call_function"
52+
# and ag_node.args[0].target == torch.ops.prims.convert_element_type.default
53+
# and is_graph_input(ag_node.args[0].args[0])
54+
# ), (
55+
# "Assume all_gather_into_tensor input is either graph input "
56+
# + f"or dtype conversion of graph input, but got {ag_node.args[0]}"
57+
# )
58+
return True
59+
return False
60+
61+
62+
# mypy: ignore-errors
63+
64+
65+
def force_recompute_fsdp_all_gather(graph: torch.fx.Graph) -> None:
66+
"""
67+
Force recompute all_gather nodes from simple fsdp in the graph.
68+
This pass should be added in torch._inductor.config.joint_custom_post_pass
69+
"""
70+
71+
def force_recompute_node(node):
72+
node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
73+
if "ac_graph_id" not in node.meta:
74+
# ac_graph_id is used in the partitioner to decide
75+
# if two nodes which have AC applied come from a different
76+
# AC regions. This is needed because nodes in the boundary
77+
# of two AC regions are marked as MUST_SAVE. In our case
78+
# we just add a large value of ac_graph_id so that
79+
# all nodes we tag for recomputation do indeed get recomputed
80+
# and are not influenced by other nodes in the graph with
81+
# nearby ac_graph_id values
82+
node.meta["ac_graph_id"] = 100000
83+
84+
# Make all-gather nodes (and related nodes) recomputable, to circumvent
85+
# https://github.com/pytorch/pytorch/issues/136433
86+
for node in graph.nodes:
87+
if is_wait_tensor_from_fsdp(node):
88+
ag_node = node.args[0]
89+
force_recompute_node(ag_node) # all_gather
90+
force_recompute_node(node) # wait_tensor
91+
# Force-recompute slice that comes after wait
92+
for user in node.users:
93+
if (
94+
user.op == "call_function"
95+
and user.target == torch.ops.aten.slice.Tensor
96+
):
97+
force_recompute_node(user)
98+
# Force-recompute potential dtype casts from all_gather
99+
if (
100+
ag_node.all_input_nodes[0].op == "call_function"
101+
and ag_node.args[0].target
102+
== torch.ops.prims.convert_element_type.default
103+
):
104+
force_recompute_node(ag_node.all_input_nodes[0])
105+
106+
107+
def mark_nodes_as_must_save_to_stage_recomputation(
108+
joint_graph: torch.fx.Graph,
109+
stage_size_in_GiB: Optional[float] = None,
110+
) -> None:
111+
"""
112+
Marks specific nodes as "must save" to optimize memory usage during recomputation.
113+
With aggressive recomputation strategies, we often encounter situations where long chains
114+
of forward nodes must be recomputed before executing backward pass nodes, causing high
115+
peak memory usage. This function breaks these recomputation chains into smaller stages
116+
based by periodically saving itermediate nodes, keeping peak memory usage below.
117+
Args:
118+
joint_graph: The joint graph containing both forward and backward nodes
119+
stage_size_in_GiB: Target memory size per stage in GiB (-1 to disable staging)
120+
"""
121+
INT_INF = int(1e9)
122+
123+
def get_required_fwd_nodes(
124+
joint_graph: torch.fx.Graph,
125+
) -> OrderedSet[torch.fx.Node]:
126+
"""
127+
Return the set of nodes that are required in the forward graph.
128+
NOTE: this is doing similar things as classify_nodes() in _functorch/partitioenrs.py
129+
where nodes are classified into three types -- fwd, bwd, and unclaimed
130+
both bwd and unclaimed nodes have partitioner_tag equal to "is_backward"
131+
"""
132+
required_fwd_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
133+
for node in joint_graph.nodes:
134+
if node.op == "placeholder" and "tangents" in node.target:
135+
continue
136+
if node.op == "output":
137+
continue
138+
if _has_tag_is_backward(node):
139+
continue
140+
required_fwd_nodes.add(node)
141+
return required_fwd_nodes
142+
143+
def get_node_distance_to_bwd(
144+
joint_graph: torch.fx.Graph,
145+
get_required_fwd_nodes: OrderedSet[torch.fx.Node],
146+
) -> dict[torch.fx.Node, int]:
147+
"""
148+
Compute and return the distance of all nodes to the closest backward node.
149+
If a node is not an ancestor of a backward node, then its distance is INT_INF.
150+
NOTE: this is adapted from
151+
https://github.com/pytorch/pytorch/blob/3196a3aca0f16792820158cfd451cb977f99ac7e/torch/_functorch/partitioners.py#L2089-L2097
152+
"""
153+
dist_from_bw = {}
154+
for node in reversed(joint_graph.nodes):
155+
if node.op == "output":
156+
dist_from_bw[node] = INT_INF
157+
elif node not in get_required_fwd_nodes:
158+
dist_from_bw[node] = 0
159+
else:
160+
dist_from_bw[node] = INT_INF
161+
for user in node.users:
162+
dist_from_bw[node] = min(dist_from_bw[node], dist_from_bw[user] + 1)
163+
return dist_from_bw
164+
165+
def get_all_recomputable_forward_nodes(
166+
joint_graph: torch.fx.Graph,
167+
) -> OrderedSet[torch.fx.Node]:
168+
"""
169+
Return the set of all forward nodes that are recomputable
170+
"""
171+
required_fwd_nodes = get_required_fwd_nodes(joint_graph)
172+
dist_from_bw = get_node_distance_to_bwd(joint_graph, required_fwd_nodes)
173+
fwd_recomputable_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
174+
for node in joint_graph.nodes:
175+
if (
176+
node in required_fwd_nodes
177+
and dist_from_bw[node] < INT_INF
178+
and node.op != "placeholder"
179+
):
180+
fwd_recomputable_nodes.add(node)
181+
return fwd_recomputable_nodes
182+
183+
def mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
184+
"""
185+
Given a list of nodes, mark them as must save.
186+
"""
187+
print(f"mark_nodes_as_must_save: {must_save_nodes}")
188+
for node in must_save_nodes:
189+
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
190+
191+
fwd_recomputable_nodes = get_all_recomputable_forward_nodes(joint_graph)
192+
193+
# Initialize all nodes as 'prefer recompute' and then adjust only the must-save ones below
194+
for node in fwd_recomputable_nodes:
195+
if node.meta.get("recompute", None) is not None:
196+
# do not mess with allgather nodes that have already been marked recompute!
197+
continue
198+
node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
199+
# add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
200+
# and is the same we add for the all-gather nodes
201+
node.meta["ac_graph_id"] = 100000
202+
203+
# get the mapping between node name and node
204+
name_to_node_mapping = {}
205+
for node in fwd_recomputable_nodes:
206+
name_to_node_mapping[node.name] = node
207+
208+
# populate node_to_predecessors, accounting for must_recompute nodes. In particular,
209+
# if a node is marked as must recompute, then for its users, their predecessors should
210+
# be updated to be instead the predecessors of the must recompute node.
211+
node_to_predecessors = defaultdict(OrderedSet)
212+
for node in fwd_recomputable_nodes:
213+
node_to_predecessors[node] = OrderedSet(
214+
[pred for pred in node.all_input_nodes if pred in fwd_recomputable_nodes]
215+
)
216+
for node in fwd_recomputable_nodes:
217+
if not must_recompute(node):
218+
continue
219+
for user in node.users:
220+
if user in fwd_recomputable_nodes:
221+
node_to_predecessors[user].remove(node)
222+
node_to_predecessors[user].update(node_to_predecessors[node])
223+
224+
# populate node_to_last_usage
225+
# if A is last used by B, then A \in node_to_last_usage[B]
226+
node_to_last_usage = defaultdict(OrderedSet)
227+
last_used_by = {}
228+
for node in fwd_recomputable_nodes:
229+
last_used_by[node] = node
230+
for pred in node_to_predecessors[node]:
231+
last_used_by[pred] = node
232+
for producer, consumer in last_used_by.items():
233+
node_to_last_usage[consumer].add(producer)
234+
235+
# loop through nodes in order of the forward graph and keep track of the following:
236+
# for each node, right before its execution, the output of what nodes are in memory.
237+
@dataclass
238+
class NodeCutScore:
239+
tot_mem: float
240+
alive_node_names: OrderedSet[str]
241+
242+
alive_nodes = OrderedSet()
243+
node2score = {}
244+
for node in fwd_recomputable_nodes:
245+
if not must_recompute(node):
246+
alive_nodes.add(node)
247+
for a in node_to_last_usage[node]:
248+
alive_nodes.remove(a)
249+
tot_mem = sum(_size_of(node) for node in alive_nodes)
250+
node2score[node] = NodeCutScore(
251+
tot_mem, OrderedSet([n.name for n in alive_nodes])
252+
)
253+
254+
# divide the graph into stages with roughly equal memory usage.
255+
stages = defaultdict(OrderedSet)
256+
cum_mem_so_far = 0
257+
curr_stage_idx = 0
258+
259+
if stage_size_in_GiB is None:
260+
total_used_memory = sum(
261+
_size_of(node)
262+
for node in fwd_recomputable_nodes
263+
if not must_recompute(node)
264+
)
265+
total_used_memory_in_GiB = total_used_memory / 2**30
266+
stage_size_in_GiB = total_used_memory_in_GiB**0.5
267+
print(f"Computed stage_size {stage_size_in_GiB=}")
268+
target_mem = stage_size_in_GiB * 2**30
269+
for node in fwd_recomputable_nodes:
270+
stages[curr_stage_idx].add(node)
271+
if not must_recompute(node):
272+
cum_mem_so_far += _size_of(node)
273+
if cum_mem_so_far >= target_mem:
274+
curr_stage_idx += 1
275+
cum_mem_so_far = 0
276+
277+
# loop through each stage and pick the best node to cut on, and save
278+
# the nodes that will be marked as must save.
279+
nodes_to_save = OrderedSet()
280+
for stage in stages.values():
281+
best_node = min(stage, key=lambda x: node2score[x].tot_mem)
282+
nodes_to_save.update(node2score[best_node].alive_node_names)
283+
mark_nodes_as_must_save([name_to_node_mapping[n] for n in nodes_to_save])
284+
285+
save_list = {
286+
torch.ops.aten.mm.default,
287+
# torch.ops.aten._scaled_dot_product_efficient_attention.default,
288+
# torch.ops.aten._scaled_dot_product_flash_attention.default,
289+
}
290+
must_save_nodes = []
291+
counter = 0
292+
for node in fwd_recomputable_nodes:
293+
if node.target in save_list:
294+
if node.target == torch.ops.aten.mm.default:
295+
if counter % 2 == 0:
296+
counter += 1
297+
else:
298+
counter += 1
299+
continue
300+
must_save_nodes.append(node)
301+
mark_nodes_as_must_save(must_save_nodes)
302+
303+
304+
def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
305+
force_recompute_fsdp_all_gather(graph)
306+
mark_nodes_as_must_save_to_stage_recomputation(
307+
graph, stage_size_in_GiB=ac_stage_size_in_GiB
308+
)

autoparallel/api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from torch.fx import GraphModule
3030
from torch.fx.experimental._backward_state import BackwardState
3131

32+
from .activation_checkpointing import ac_joint_pass
3233
from .apply_sharding import apply_sharding_to_model
3334
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
3435
from .init_weights import hook_params_setters
@@ -203,6 +204,9 @@ def __init__(
203204
mesh: DeviceMesh,
204205
mp_policy: Optional[MixedPrecisionPolicy] = None,
205206
compile: bool = False,
207+
enable_ac: bool = True,
208+
# None means 'auto'
209+
ac_stage_size_in_GiB: Optional[float] = None,
206210
):
207211
self.stack = ExitStack()
208212
self.fake_mode = (
@@ -228,9 +232,10 @@ def __init__(
228232
self.input_fn = input_fn
229233
self.mesh = mesh
230234
self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
235+
self.enable_ac = enable_ac
236+
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
231237

232238
# NB: rest of the construction happens in __enter__
233-
234239
self.active = False
235240

236241
def __enter__(self):
@@ -439,6 +444,9 @@ def apply_placement(self, sharding_placement=None):
439444
},
440445
payload_fn=lambda: str(parallel_gm.graph),
441446
)
447+
448+
if self.enable_ac:
449+
ac_joint_pass(parallel_gm.graph, self.ac_stage_size_in_GiB)
442450
# now rename input/param/tangent/output/grad_param/grad_input nodes following
443451
# our convention
444452
# apply_node_renaming(

0 commit comments

Comments
 (0)