Commit f5e3a84

CUDAGraph support for SimpleFSDP and TP (#2050)
## Features
- [x] Support SimpleFSDP and TP
- [x] Support static input indices to reduce copies
- [x] Support memory reuse to reduce memory consumption
- [x] Clean up cudagraphs when training finishes, to avoid an NCCL hang from destroy_process_group

Command:
```
NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph
```

Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue where NCCL + cudagraphs + expandable segments results in an IMA: pytorch/pytorch#158029

[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json)

## Result

**Numerics:** Achieved bitwise equivalence with and without the cudagraph pass on llama3.1-8B and llama3.1-70B.

**Performance:**

<img width="560" height="90" alt="image" src="https://github.com/user-attachments/assets/9d54c461-0eb1-4f7e-9652-3d52043ad74f" />

Raw logs: [llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190), [llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416)

**Memory:** On llama3.1-70B, cudagraph uses about 6% more memory (153 GiB vs. 143 GiB without cudagraph). A few tricks reduce memory consumption (using llama3.1-70B with cudagraph as the example):

- Start: 161 GiB
- + use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB
- + warm up in the cudagraph memory pool instead of the eager memory pool: 153 GiB

**Static input copy:** On llama3.1-70B, the forward copies 1 tensor of 128 bytes and the backward copies 1 tensor of 0.98 GB, which shows static input indices are handled correctly.

## Followup PR
In a followup PR, I will enable fx graph partition for DeepSeek-V3: pytorch/pytorch#165945
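For orientation, here is a minimal, standalone sketch of the capture/replay pattern these memory tricks refer to: a shared memory pool, a single stream for warmup and capture, and copying only non-static inputs before replay. It uses only public `torch.cuda` graph APIs and is illustrative, not the wrapper added in this PR:

```python
import torch

# One memory pool and one capture stream shared by all graphs, mirroring the
# "same stream for warmup and capture" and pool-reuse tricks described above.
pool = torch.cuda.graph_pool_handle()
capture_stream = torch.cuda.Stream()

def make_graphed(fn, capture_inp):
    # Warm up on the capture stream so lazy initialization (cuBLAS, NCCL, ...)
    # happens before capture.
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        fn(capture_inp)
    torch.cuda.current_stream().wait_stream(capture_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool, stream=capture_stream):
        out = fn(capture_inp)  # outputs and temporaries land in the shared pool

    def replay(new_inp):
        capture_inp.copy_(new_inp)  # the graph reads from capture_inp's fixed address
        graph.replay()
        return out                  # same tensor object every call, refreshed by replay

    return replay

if torch.cuda.is_available():
    x = torch.randn(8, 8, device="cuda")
    graphed_step = make_graphed(lambda t: torch.relu(t @ t), x)
    y = graphed_step(torch.randn(8, 8, device="cuda"))
```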
1 parent 8bf2265 commit f5e3a84

File tree

7 files changed, +292 -10 lines changed


torchtitan/experiments/compiler_toolkit/README.md

Lines changed: 6 additions & 0 deletions
@@ -55,3 +55,9 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to
 ```shell
 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor
 ```
+
+**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph**
+
+```shell
+NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph
+```

torchtitan/experiments/compiler_toolkit/common_utils.py

Lines changed: 9 additions & 0 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from contextlib import contextmanager
+from typing import Callable

 import torch
 from torch.distributed.tensor import DTensor, Replicate
@@ -53,3 +54,11 @@ def register_blockmask_pytree_node():
         flatten_with_keys_fn=BlockMask._flatten_with_keys,
         serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
     )
+
+
+def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
+    return (
+        len(passes) > 0
+        and (last_pass_name := getattr(passes[-1], "__name__", None))
+        and (last_pass_name in names)
+    )
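For illustration, a small example of how this helper behaves: it returns a truthy value only when the pass list is non-empty and the last pass's `__name__` is in `names`. The pass functions below are hypothetical stand-ins, not the real passes defined elsewhere in this PR:

```python
from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass

def regional_inductor_pass(gm, example_inputs):  # stand-in for illustration
    return gm

def cudagraph_pass(gm, example_inputs, is_forward):  # stand-in for illustration
    return gm

end_with_pass([regional_inductor_pass, cudagraph_pass], ["cudagraph_pass"])  # True
end_with_pass([regional_inductor_pass], ["cudagraph_pass"])                  # False
end_with_pass([], ["cudagraph_pass"])                                        # False
```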
torchtitan/experiments/compiler_toolkit/cudagraph.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
CUDAGraph pass for the compiler toolkit.

This module provides a cudagraph pass that can be applied to graph modules
during compilation.
"""

import warnings
from typing import Any, Callable, Optional, Sequence

import torch
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
from torch.utils._ordered_set import OrderedSet


def init_global_graph_pool() -> tuple[
    torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
]:
    dummy_graph = torch.cuda.CUDAGraph()

    # create a global cudagraph memory pool to allow memory reuse across cudagraphs.
    graph_pool = torch.cuda.graph_pool_handle()

    # create a global cuda stream for graph capture. we need to use a single stream
    # for all allocations to the memory pool, otherwise allocations made on separate
    # streams will not be reused.
    graph_capture_stream = torch.cuda.Stream()

    # use a dummy graph to keep the global graph pool alive
    with (
        # suppress an empty cudagraph warning, since we intentionally create
        # an empty cudagraph here
        warnings.catch_warnings(record=True),
        torch.cuda.graph(
            dummy_graph,
            pool=graph_pool,
            stream=graph_capture_stream,
            capture_error_mode="thread_local",
        ),
    ):
        pass

    return dummy_graph, graph_pool, graph_capture_stream


(
    _global_dummy_graph,
    _global_graph_pool,
    _global_graph_capture_stream,
) = init_global_graph_pool()


class CUDAGraphWrapper:
    def __init__(
        self,
        runnable: Callable,
        example_inputs: Sequence[Any],
        static_input_indices: Optional[tuple[int]] = None,
        should_check_address: bool = False,
    ):
        self.runnable = runnable
        self.graph_pool = _global_graph_pool
        self.stream = _global_graph_capture_stream
        self.static_input_indices = OrderedSet(
            static_input_indices if static_input_indices is not None else []
        )
        self.input_indices_to_copy = [
            i
            for i, inp in enumerate(example_inputs)
            if isinstance(inp, torch.Tensor) and i not in self.static_input_indices
        ]
        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
        self.has_warmup = False

        self.args = None
        self.output = None

        # (debug only) whether to check static input tensor addresses at runtime
        self.should_check_address = should_check_address

    def copy_non_static_inputs(self, *args):
        for i in self.input_indices_to_copy:
            self.args[i].copy_(args[i])

    def check_input_types(self, inputs) -> None:
        for inp in inputs:
            assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
                "args must be tensor, integer (for dynamic shapes), "
                "or Generator (for random number generator), "
                f"but found {type(inp)}"
            )

    def check_static_inputs_address(self, args) -> None:
        for i in self.static_input_indices:
            actual = args[i].data_ptr()
            expected = self.input_addresses[i]
            assert expected == actual, (
                "Expected the same static tensor address but found "
                f"{expected} != {actual}"
            )

    def __call__(self, *args):
        if not self.has_warmup:
            self.has_warmup = True
            device = torch.cuda.current_device()

            # warmup in cudagraph memory pool to avoid fragmentation
            # across eager memory pool and cudagraph memory pool.
            with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream):
                out = self.runnable(*args)
            return out

        if self.cudagraph is None:
            self.check_input_types(args)
            self.args = args
            self.input_addresses = [
                x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
            ]

            self.cudagraph = torch.cuda.CUDAGraph()

            with torch.cuda.graph(
                self.cudagraph, pool=self.graph_pool, stream=self.stream
            ):
                # `output` is managed by pytorch's cudagraph pool
                self.output = self.runnable(*args)

        if self.should_check_address:
            self.check_static_inputs_address(args)

        self.copy_non_static_inputs(*args)
        self.cudagraph.replay()
        return self.output


def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:
    """
    Get indices of gm inputs that are static input tensors whose tensor addresses do
    not change across runs. Examples of static input tensors include weights, buffers,
    and outputs of previous cudagraph-wrapped functions.
    """
    from torch._inductor.utils import count_tangents

    static_input_indices = []
    if (
        is_forward
        and (tracing_context := torch._guards.TracingContext.try_get())
        and hasattr(tracing_context, "fw_metadata")
    ):
        # for forward, we rely on graph capture (i.e., dynamo or export) to provide
        # the correct static input indices stored in tracing context. Typical examples
        # include weights and buffers.
        static_input_indices = tracing_context.fw_metadata.static_input_indices

    elif not is_forward:
        # for backward, we identify saved tensors as static inputs, since saved tensors
        # are outputs of the cudagraph-wrapped forward run. In the PT2-generated
        # backward gm, saved tensors are always the leading args, so we can get the
        # number of saved tensors and generate static input indices.
        fixed = count_tangents(gm)
        static_input_indices = list(range(fixed))

    return static_input_indices

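To make the runtime behavior above concrete, here is a minimal usage sketch (not part of this commit) of calling the wrapper directly. The wrapped function and shapes are made up, and index 1 is marked static because the weight's address is assumed not to change:

```python
import torch

from torchtitan.experiments.compiler_toolkit.cudagraph import CUDAGraphWrapper

def step(x, w):
    return torch.relu(x @ w)

if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda")   # non-static input: copied before replay
    w = torch.randn(16, 16, device="cuda")  # static input: address assumed fixed

    wrapped = CUDAGraphWrapper(step, example_inputs=(x, w), static_input_indices=(1,))

    out1 = wrapped(x, w)                    # 1st call: eager warmup in the graph pool
    out2 = wrapped(x, w)                    # 2nd call: record the cudagraph, then replay
    out3 = wrapped(torch.randn_like(x), w)  # later calls: copy x into place and replay
```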
torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 41 additions & 9 deletions
@@ -20,6 +20,7 @@
 from torch.distributed.tensor import DTensor
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass
 from torchtitan.tools.logging import logger


@@ -217,6 +218,7 @@ def compiler(
     example_inputs,
     passes: List[Callable] = None,
     dump_folder: str | None = None,
+    is_forward: bool = True,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -239,6 +241,17 @@
     )
     _dump_gm(dump_folder, gm, f"{name}_before_compiler")

+    if end_with_pass(passes, ["cudagraph_pass"]):
+        # cudagraph pass is always the last pass if it is applied
+        cg_pass = passes[-1]
+
+        # To identify static input indices, the cudagraph pass behaves differently
+        # for the forward and backward pass, so we explicitly pass that info.
+        _cg_pass = functools.partial(cg_pass, is_forward=is_forward)
+
+        # keep the function name for the debug log
+        passes[-1] = functools.wraps(cg_pass)(_cg_pass)
+
     for pass_fn in passes:
         pass_name = (
             pass_fn.func.__name__
@@ -271,17 +284,42 @@ def make_compiler_with_passes(

     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
         return compiler(
-            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+            "fwd_gm",
+            gm,
+            example_inputs,
+            passes=passes,
+            dump_folder=dump_folder,
+            is_forward=True,
         )

     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
         return compiler(
-            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+            "bwd_gm",
+            gm,
+            example_inputs,
+            passes=passes,
+            dump_folder=dump_folder,
+            is_forward=False,
         )

     return fw_compiler, bw_compiler


+def validate_pass_names(pass_names: list[str]) -> None:
+    if "cudagraph" in pass_names:
+        assert (
+            pass_names[-1] == "cudagraph"
+        ), "cudagraph has to be the last pass to apply"
+
+    if (
+        "autobucketing_reordering" in pass_names
+        and "transformer_block_bucketing" in pass_names
+    ):
+        raise ValueError(
+            "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
+        )
+
+
 def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
     """
     Extract and validate compiler passes from job config.
@@ -298,13 +336,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
     )

     pass_names = getattr(job_config.compile, "passes", [])
-    if (
-        "autobucketing_reordering" in pass_names
-        and "transformer_block_bucketing" in pass_names
-    ):
-        raise ValueError(
-            "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
-        )
+    validate_pass_names(pass_names)
     compiler_passes = []

     for pass_name in pass_names:

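One detail worth noting from the `compiler()` change above: binding `is_forward` with `functools.partial` would normally lose the pass's `__name__`, so the result is re-wrapped with `functools.wraps` to keep the per-pass debug log readable. A minimal sketch of that pattern (the pass body is elided):

```python
import functools

def cudagraph_pass(gm, example_inputs, is_forward):
    ...  # the real pass body lives in passes.py

cg_pass = functools.partial(cudagraph_pass, is_forward=True)
cg_pass = functools.wraps(cudagraph_pass)(cg_pass)

print(cg_pass.__name__)       # "cudagraph_pass": the name survives the partial
print(cg_pass.func.__name__)  # also "cudagraph_pass"; compiler() logs via .func
```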
torchtitan/experiments/compiler_toolkit/passes.py

Lines changed: 24 additions & 0 deletions
@@ -11,10 +11,16 @@
 during compilation. Passes can be selected and configured via job config.
 """

+from typing import Any, Sequence
+
 import torch
 from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
 from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
 from torch.fx.passes.regional_inductor import regional_inductor
+from torchtitan.experiments.compiler_toolkit.cudagraph import (
+    CUDAGraphWrapper,
+    get_static_input_indices,
+)
 from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
     annotate_fsdp_all_gather,
 )
@@ -56,6 +62,23 @@ def regional_inductor_pass(
     return regional_inductor(gm, example_inputs)


+def cudagraph_pass(
+    gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool
+) -> torch.fx.GraphModule:
+    """
+    Apply cudagraph.
+
+    This pass wraps the forward function with cudagraph during compilation and does
+    not record cudagraph until runtime.
+    - For the first run, it will warm up operators such as nccl.
+    - For the second run, it will record cudagraph and replay cudagraph.
+    - For the following runs, it will replay cudagraph.
+    """
+    static_input_indices = get_static_input_indices(gm, is_forward)
+    gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices)
+    return gm
+
+
 def validate_flex_attn_annotation_pass(
     gm: torch.fx.GraphModule,
 ) -> torch.fx.GraphModule:
@@ -88,4 +111,5 @@ def fsdp_reshard_after_fwd_pass(
     "autobucketing_reordering": autobucketing_reordering_pass,
     "transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
     "regional_inductor": regional_inductor_pass,
+    "cudagraph": cudagraph_pass,
 }

torchtitan/experiments/compiler_toolkit/tests/integration_tests.py

Lines changed: 29 additions & 0 deletions
@@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
             "llama3_fsdp_tp_manualbucketing",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name compiler_toolkit.llama3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
+                    "--compile.passes cudagraph",
+                ],
+            ],
+            "llama3 FSDP+TP+cudagraph",
+            "llama3_fsdp_tp_cudagraph",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -86,6 +100,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]:
             "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name compiler_toolkit.llama3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--model.flavor debugmodel_flex_attn",
+                    "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config",
+                    "--compile.passes autobucketing_reordering,regional_inductor,cudagraph",
+                ],
+            ],
+            "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph",
+            "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
