
Commit 897ad9d

BoyuanFeng authored and Silv3S committed
Inductor Lite Mode (pytorch#167115)
This PR introduces Inductor lite mode for opt-in optimizations and numeric correctness guarantees. Unlike the default mode, which applies all possible fusions, lite mode gives control back to the user and guarantees numeric correctness. Specifically, this mode:

- **Fallback by default**: falls back to the aten op for ALL nodes, unless users explicitly mark a node for Inductor fusion.
- **Selective decomposition**: skips decomposition for all nodes except user-marked ones.
- **Regional Inductor compile**
- Skips dead code elimination.
- Skips buffer reuse.
- Skips reorder passes, such as reorder for peak memory, reorder for compute/comm overlap, and reorder_for_reducing_graph_partitions.
- Skips all pre-grad, joint-graph, and post-grad passes.

## Example: Flex Attention

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def _squared(score, b, h, m, n):
    return score * score


def mask_mod(b, h, q, k):
    return q >= 0


a, b = 12, 64
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b, device="cuda")


def fn(x):
    x = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)


x = torch.randn(1, 1, a * b, b, dtype=torch.bfloat16, device="cuda", requires_grad=True)

opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
opt_fn(x)
```

[code diff](https://www.internalfb.com/intern/diffing/?paste_number=2027441476)

[default mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpYAzDxX/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs [lite mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpnnuh1W/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)

## Numerics

Inductor lite mode provides bitwise equivalence with the `aot_eager` backend on torchtitan llama3-8b and DeepSeek v3 (pytorch/torchtitan#2005).

Closes: pytorch#167012
Pull Request resolved: pytorch#167115
Approved by: https://github.com/ezyang
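A smaller end-to-end sketch, distilled from the tests added in this PR (it assumes a build that contains this change; `compile_with_inductor` is the annotation key the regional inductor pass looks for):

```python
import torch
import torch.fx.traceback as fx_traceback


def fn(x, y):
    sin = torch.sin(x)
    # Only this annotated region is handed to Inductor for fusion; every other
    # node falls back to its aten op under lite mode.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
out = opt_fn(torch.randn(10), torch.randn(10))
```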
1 parent 76100f9 commit 897ad9d

File tree

11 files changed (+486, -18 lines)


test/inductor/test_torchinductor.py

Lines changed: 219 additions & 0 deletions
@@ -30,6 +30,7 @@
 import torch
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
+import torch.fx.traceback as fx_traceback
 import torch.nn as nn
 from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
@@ -13564,6 +13565,224 @@ def f(image_latent):
         size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
         FileCheck().check_regex(size_assert_pattern).run(code)
 
+    def test_lite_mode_fallback(self):
+        def f(x):
+            z = x.sin()
+            return z.cos()
+
+        f = torch.compile(f, mode="lite")
+
+        _, code = run_and_get_code(f, torch.randn(2, device=self.device))
+
+        # Checks that aten ops are kept and run
+        if config.cpp_wrapper:
+            FileCheck().check("aoti_torch_call_dispatcher(").check("aten::sin").check(
+                "aoti_torch_call_dispatcher("
+            ).check("aten::cos").run(code[0])
+        else:
+            FileCheck().check("torch.ops.aten.sin.default(").check(
+                "torch.ops.aten.cos.default("
+            ).run(code[0])
+        # Checks that no triton code runs in the generated code
+        self.assertFalse(".run(" in code[0])
+
+    # skip cpu test since rms norm is always decomposed on cpu
+    def test_lite_mode_not_decompose(self):
+        if self.device != GPU_TYPE or self.device == "mps":
+            raise unittest.SkipTest("requires GPU")
+
+        def f(x, shape):
+            y = x + 1
+            z = torch.ops.aten._fused_rms_norm(y, shape, None, None)
+            return z[0] + z[1]
+
+        f = torch.compile(f, mode="lite")
+
+        x = torch.randn(2, 3, device=self.device)
+        _, code = run_and_get_code(f, x, [2, 3])
+        if config.cpp_wrapper:
+            FileCheck().check(
+                "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda__fused_rms_norm("
+            ).run(code[0])
+        else:
+            FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
+
+        if config.cpp_wrapper:
+            # arg type List[int] is not yet supported by custom_op_wrapper
+            pass
+        else:
+            x = torch.randn(2, 3, device=self.device, requires_grad=True)
+            _, codes = run_fw_bw_and_get_code(lambda: f(x, [2, 3]))
+            self.assertEqual(len(codes), 2)
+            FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
+
+    def test_lite_regional_compile_flex_attention(self):
+        if self.device != GPU_TYPE or self.device == "mps":
+            raise unittest.SkipTest("requires GPU")
+
+        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+        def _squared(score, b, h, m, n):
+            return score * score
+
+        def mask_mod(b, h, q, k):
+            return q >= 0
+
+        a = 12
+        b = 64
+        block_mask = create_block_mask(
+            mask_mod, None, None, a * b, a * b, device=self.device
+        )
+
+        def fn(x):
+            x = torch.sin(x)
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
+            return torch.cos(x)
+
+        x = torch.randn(
+            1,
+            1,
+            a * b,
+            b,
+            dtype=torch.bfloat16,
+            device=self.device,
+            requires_grad=True,
+        )
+
+        opt_fn = torch.compile(
+            fn,
+            mode="lite",
+            fullgraph=True,
+        )
+
+        # Check that inductor compilation is called twice
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
+        self.assertEqual(len(codes), 2)
+
+    @unittest.skipIf(
+        config.cpp_wrapper,
+        "codegen invoke_subgraph is not implemented for cpp wrapper",
+    )
+    def test_lite_regional_compile_invoke_subgraph(self):
+        # Checks that get_attr nodes' custom metadata is propagated
+        @torch.compiler.nested_compile_region
+        def gn(x):
+            return torch.sin(x)
+
+        def fn(x):
+            x = x + 1
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                z = gn(x)
+            return torch.sigmoid(z)
+
+        opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+
+        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
+        self.assertEqual(len(codes), 2)
+
+    @unittest.skipIf(
+        config.cpp_wrapper,
+        "codegen triton_kernel_wrapper_functional is not implemented for cpp wrapper",
+    )
+    def test_lite_triton_kernel_wrapper_functional(self):
+        if self.device != GPU_TYPE or self.device == "mps":
+            raise unittest.SkipTest("requires GPU")
+
+        from torch._higher_order_ops.triton_kernel_wrap import (
+            kernel_side_table,
+            triton_kernel_wrapper_functional,
+        )
+        from torch.testing._internal.triton_utils import mul2_kernel
+
+        kernel_side_table.reset_table()
+
+        def f(x, output):
+            out = triton_kernel_wrapper_functional(
+                kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
+                constant_args_idx=kernel_side_table.add_constant_args(
+                    {"n_elements": output.numel(), "BLOCK_SIZE": 16}
+                ),
+                grid=[(x.numel(),)],
+                tma_descriptor_metadata={},
+                kwargs={
+                    "in_ptr0": x,
+                    "out_ptr": output,
+                },
+                tensors_to_clone=["in_ptr0", "out_ptr"],
+            )
+            return out["out_ptr"]
+
+        t1 = torch.rand(5, device=self.device)
+        t2 = torch.rand(5, device=self.device)
+
+        compiled_f = torch.compile(f, mode="lite")
+        out = compiled_f(t1, t2)
+
+        # Make sure t2 was not modified
+        self.assertNotEqual(out, t2)
+
+    def test_lite_regional_compile_repeated_blocks(self):
+        def fn(x, y):
+            sin = torch.sin(x)
+
+            with fx_traceback.annotate({"compile_with_inductor": 0}):
+                mul = sin * y
+                add = mul + 1
+
+            return torch.sin(add)
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                a = fn(x, y)
+                return fn(a, y)
+
+        mod = Mod()
+
+        opt_mod = torch.compile(
+            mod,
+            mode="lite",
+            fullgraph=True,
+        )
+        x = torch.randn(10, requires_grad=True)
+        y = torch.randn(10, requires_grad=True)
+
+        _, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
+        self.assertEqual(len(codes), 2)
+
+    def test_lite_dynamic_shape_assertion(self):
+        class Model(torch.nn.Module):
+            def forward(self, c):
+                d = torch.concat([c, c], dim=0)
+                with fx_traceback.annotate({"compile_with_inductor": "my_region"}):
+                    d = d + 1
+                return d
+
+        model = Model()
+        model = torch.compile(
+            model,
+            mode="lite",
+            fullgraph=True,
+        )
+
+        c = torch.randn((64, 32), device=self.device)
+        torch._dynamo.decorators.mark_unbacked(c, 0)
+
+        _, code = run_and_get_code(model, c)
+        # Checks that unbacked symint assertions are kept
+        if config.cpp_wrapper:
+            FileCheck().check_regex(r"if \(!\(u.* >= 0L\)\)").check_regex(
+                "Expected u.* >= 0 but receive"
+            ).run(code[0])
+        else:
+            FileCheck().check_regex(r"if not \(u.* >= 0\):").check_regex(
+                r"raise RuntimeError\('u.* >= 0'\)"
+            ).run(code[0])
+
     @lowering.force_fallback(aten.sort.default)
     @unittest.skipIf(
         config.cpp_wrapper,

torch/_functorch/_aot_autograd/graph_compile.py

Lines changed: 43 additions & 13 deletions
@@ -103,6 +103,17 @@
 _thread_local = threading.local()
 
 
+@contextmanager
+def maybe_skip_decompose(aot_config: AOTConfig):
+    old_decomp = aot_config.decompositions
+    try:
+        if config.selective_decompose:
+            aot_config.decompositions = {}
+        yield
+    finally:
+        aot_config.decompositions = old_decomp
+
+
 # Saved tensor hooks context
 # Compiled saved tensor hooks are convenient way to inline some logic in the graphs
 # for saved nodes from forward to backward. (E.g. activations quantization)
@@ -196,27 +207,46 @@ def orig_flat_fn2(*args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]:
     # deterministic TLS can be different
     aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
     updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]]
-    if aot_state.needs_autograd and not aot_config.pre_dispatch:
-        # FYI: this being moved to trigger in export is new, seems fine!
-        with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
+
+    with maybe_skip_decompose(aot_config):
+        # if config.selective_decompose, skip decomposition and apply selective_decompose
+        # after we get the joint graph. See [Note: Selective Decomposition] for details.
+        if aot_state.needs_autograd and not aot_config.pre_dispatch:
+            # FYI: this being moved to trigger in export is new, seems fine!
+            with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
+                (
+                    graph,
+                    updated_flat_args,
+                    updated_flat_args_descs,
+                    maybe_subclass_meta,
+                ) = aot_dispatch_autograd_graph(
+                    flat_fn,
+                    aot_state.flat_args,
+                    aot_state.flat_args_descs,
+                    aot_config,
+                    fw_metadata=aot_state.fw_metadata,
+                )
+        else:
             graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
-                aot_dispatch_autograd_graph(
+                aot_dispatch_base_graph(
                     flat_fn,
                     aot_state.flat_args,
                     aot_state.flat_args_descs,
                     aot_config,
                     fw_metadata=aot_state.fw_metadata,
                 )
             )
-    else:
-        graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
-            aot_dispatch_base_graph(  # type: ignore[assignment]
-                flat_fn,
-                aot_state.flat_args,
-                aot_state.flat_args_descs,
-                aot_config,
-                fw_metadata=aot_state.fw_metadata,
-            )
-        )
+
+        if config.selective_decompose:
+            from torch.fx.experimental.proxy_tensor import selective_decompose
+            from torch.fx.passes.regional_inductor import _needs_inductor_compile
+
+            graph = selective_decompose(
+                graph,
+                *updated_flat_args,
+                decomposition=aot_config.decompositions,
+                should_decompose=_needs_inductor_compile,
+                trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch,
+            )
 
     return AOTGraphCapture(
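For intuition, a hypothetical sketch of what a `should_decompose` predicate in the spirit of `_needs_inductor_compile` could look like; this is not the actual implementation, and it assumes `fx_traceback.annotate` records its dict under `node.meta["custom"]`:

```python
import torch.fx


def needs_inductor_compile(node: torch.fx.Node) -> bool:
    # Hypothetical predicate: only nodes created inside an
    # fx_traceback.annotate({"compile_with_inductor": ...}) region are
    # decomposed and compiled by Inductor (assumes the annotation is stored
    # under node.meta["custom"]).
    custom = node.meta.get("custom") or {}
    return "compile_with_inductor" in custom
```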

torch/_functorch/config.py

Lines changed: 7 additions & 0 deletions
@@ -374,6 +374,13 @@ def remote_autograd_cache_default() -> Optional[bool]:
 # This callback is invoked on the joint graph before partitioning
 joint_custom_pass: Callable = None  # type: ignore[assignment]
 
+# Note [Selective Decomposition]
+# This config allows selective decomposition of certain operators in the graph.
+# When True, it does NOT decompose any nodes except those that users explicitly
+# annotated for regional inductor compile. Please read torch.fx.passes.regional_inductor
+# on how to explicitly annotate. This is currently only used by inductor lite mode.
+selective_decompose: bool = False
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
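For illustration only, a hedged sketch of flipping this flag directly rather than through `mode="lite"` (which sets it together with `fallback_by_default`); it assumes the standard `patch` helper that PyTorch config modules expose, and standalone use like this is not the documented path:

```python
import torch
import torch._functorch.config as functorch_config
import torch.fx.traceback as fx_traceback


def f(x):
    y = x + 1
    # Under selective_decompose, only the annotated region is decomposed.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        y = torch.sin(y) * 2
    return y.cos()


with functorch_config.patch(selective_decompose=True):
    out = torch.compile(f, fullgraph=True)(torch.randn(8))
```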

torch/_inductor/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -315,6 +315,25 @@ def aot_compile(
     )
 
 
+lite_mode_options = {
+    # Fallback by default unless users explicitly annotate with
+    # regional inductor compile.
+    "fallback_by_default": True,
+    "selective_decompose": True,
+    # Disable reorder optimizations
+    "reorder_for_peak_memory": False,
+    "reorder_for_compute_comm_overlap": False,
+    "triton.reorder_for_reducing_graph_partitions": False,
+    # Disable pre-, joint-, and post-grad passes
+    "use_pre_grad_passes": False,
+    "use_joint_graph_passes": False,
+    "use_post_grad_passes": False,
+    # Disable dead code elimination (dce) and buffer reuse
+    "use_dce": False,
+    "allow_buffer_reuse": False,
+}
+
+
 def list_mode_options(
     mode: Optional[str] = None, dynamic: Optional[bool] = None
 ) -> dict[str, Any]:
@@ -332,6 +351,8 @@ def list_mode_options(
 
     mode_options: dict[str, dict[str, bool]] = {
         "default": {},
+        # lite backend for opt-in optimizations
+        "lite": lite_mode_options,
         # enable cudagraphs
         "reduce-overhead": {
             "triton.cudagraphs": True,
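To see what `mode="lite"` expands to at runtime, the existing `list_mode_options` helper can be used to inspect the mapping above (the printed dict is sketched, not quoted verbatim):

```python
import torch
from torch._inductor import list_mode_options

# Should mirror lite_mode_options above, e.g.
# {'fallback_by_default': True, 'selective_decompose': True, ...}
print(list_mode_options("lite"))

# torch.compile consumes the same mapping when mode="lite" is passed:
opt = torch.compile(lambda x: x.sin().cos(), mode="lite")
```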

0 commit comments
