From 7e8185f24db7c41f7dd70b2e51dce4bc0e8bdfdb Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 7 Nov 2025 17:41:49 -0800 Subject: [PATCH 01/15] add lite mode --- .../compiler_toolkit/graph_utils.py | 46 ++++++++++++++++++- .../compiler_toolkit/llama3/parallelize.py | 10 ++-- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 4ff6c8187b..fe3f07d9af 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -106,7 +106,7 @@ def joint_graph_builder( if joint_custom_pass is not None: joint_custom_pass(joint_with_descriptors) - with tracing(tracing_context): + with tracing(tracing_context), torch._functorch.config.patch(selective_decompose=True): fn = aot_compile_joint_with_descriptors( joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler ) @@ -122,6 +122,50 @@ def wrapper_fn(args, kwargs): return wrapper_fn +def get_inductor_lite_fw_compiler(): + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor import lite_mode_options + + context = torch._guards.TracingContext.try_get() + + if not context or not context.fw_metadata: + logger.warn("No context or fw_metadata available") + static_input_idxs = () + else: + static_input_idxs = context.fw_metadata.static_input_indices + + def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): + with torch._inductor.config.patch(lite_mode_options): + compiled_fn = compile_fx_inner( + gm, + example_inputs, + static_input_idxs=static_input_idxs, + is_backward=False, + ) + return compiled_fn + + return fw_compiler + + +def get_inductor_lite_bw_compiler(): + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor import lite_mode_options + from torch._inductor.utils import count_tangents + + def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): + fixed = count_tangents(gm) + + with torch._inductor.config.patch(lite_mode_options): + compiled_fn = compile_fx_inner( + gm, + example_inputs, + static_input_idxs=list(range(fixed)), + is_backward=True, + ) + return compiled_fn + + return bw_compiler + class CompiledModule(torch.nn.Module): def __init__( self, diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 0ed8452148..f62efc3698 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -24,6 +24,8 @@ from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, joint_graph_builder, + get_inductor_lite_fw_compiler, + get_inductor_lite_bw_compiler, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, @@ -53,12 +55,12 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) - + gm = compiler("fwd_gm", gm, example_inputs) + return get_inductor_lite_fw_compiler()(gm, example_inputs) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) - + gm = compiler("bwd_gm", gm, example_inputs) + return get_inductor_lite_bw_compiler()(gm, example_inputs) def validate_flex_attention_annotation(joint_with_descriptors): """Verify user 
annotations show up in the graph.""" From cd6a18791f639cbe94943d95f1bbaeb777a7b20b Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Sun, 9 Nov 2025 21:56:39 -0800 Subject: [PATCH 02/15] add lite mode for dsv3 --- .../deepseek_v3/parallelize.py | 14 ++++++++-- .../compiler_toolkit/graph_utils.py | 27 +++++++++++++------ .../compiler_toolkit/llama3/parallelize.py | 6 +++-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 5c8ffb45c5..1d09bafafc 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -21,6 +21,8 @@ from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, + get_inductor_lite_bw_compiler, + get_inductor_lite_fw_compiler, joint_graph_builder, ) @@ -43,11 +45,19 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) + gm = compiler("fwd_gm", gm, example_inputs) + + # TODO: fix inductor size assertion for all_reduce + extra_inductor_config = {"size_asserts": False} + return get_inductor_lite_fw_compiler(extra_inductor_config)(gm, example_inputs) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) + gm = compiler("bwd_gm", gm, example_inputs) + + # TODO: fix inductor size assertion for all_reduce + extra_inductor_config = {"size_asserts": False} + return get_inductor_lite_bw_compiler(extra_inductor_config)(gm, example_inputs) def annotate_deepseekv3() -> None: diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index fe3f07d9af..87ac18d243 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -106,7 +106,9 @@ def joint_graph_builder( if joint_custom_pass is not None: joint_custom_pass(joint_with_descriptors) - with tracing(tracing_context), torch._functorch.config.patch(selective_decompose=True): + with tracing(tracing_context), torch._functorch.config.patch( + selective_decompose=True + ): fn = aot_compile_joint_with_descriptors( joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler ) @@ -122,9 +124,9 @@ def wrapper_fn(args, kwargs): return wrapper_fn -def get_inductor_lite_fw_compiler(): - from torch._inductor.compile_fx import compile_fx_inner +def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None): from torch._inductor import lite_mode_options + from torch._inductor.compile_fx import compile_fx_inner context = torch._guards.TracingContext.try_get() @@ -134,8 +136,12 @@ def get_inductor_lite_fw_compiler(): else: static_input_idxs = context.fw_metadata.static_input_indices + inductor_config = lite_mode_options + if extra_config: + inductor_config.update(extra_config) + def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): - with torch._inductor.config.patch(lite_mode_options): + with torch._inductor.config.patch(inductor_config): compiled_fn = compile_fx_inner( gm, example_inputs, @@ -147,15 +153,19 @@ def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): return fw_compiler -def get_inductor_lite_bw_compiler(): - from torch._inductor.compile_fx import compile_fx_inner +def 
get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None): from torch._inductor import lite_mode_options + from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.utils import count_tangents + inductor_config = lite_mode_options + if extra_config: + inductor_config.update(extra_config) + def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): fixed = count_tangents(gm) - - with torch._inductor.config.patch(lite_mode_options): + + with torch._inductor.config.patch(inductor_config): compiled_fn = compile_fx_inner( gm, example_inputs, @@ -166,6 +176,7 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): return bw_compiler + class CompiledModule(torch.nn.Module): def __init__( self, diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index f62efc3698..3bae88d1bc 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -23,9 +23,9 @@ from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, - joint_graph_builder, - get_inductor_lite_fw_compiler, get_inductor_lite_bw_compiler, + get_inductor_lite_fw_compiler, + joint_graph_builder, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, @@ -58,10 +58,12 @@ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: gm = compiler("fwd_gm", gm, example_inputs) return get_inductor_lite_fw_compiler()(gm, example_inputs) + def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: gm = compiler("bwd_gm", gm, example_inputs) return get_inductor_lite_bw_compiler()(gm, example_inputs) + def validate_flex_attention_annotation(joint_with_descriptors): """Verify user annotations show up in the graph.""" for node in joint_with_descriptors.graph_module.graph.nodes: From f6307da5d96f0ac36ea6d0d3e0680ba671dfbbe6 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 14:13:03 -0800 Subject: [PATCH 03/15] fix conflict --- .../deepseek_v3/parallelize.py | 9 ++- .../compiler_toolkit/graph_utils.py | 71 ++++--------------- .../compiler_toolkit/inductor_lite.py | 69 ++++++++++++++++++ .../compiler_toolkit/llama3/parallelize.py | 9 ++- 4 files changed, 99 insertions(+), 59 deletions(-) create mode 100644 torchtitan/experiments/compiler_toolkit/inductor_lite.py diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 982843bb24..704499f231 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -25,6 +25,8 @@ get_joint_custom_passes_from_config, joint_graph_builder, make_compiler_with_passes, + GraphBuilderOptions, + is_using_inductor_lite, ) from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import ( @@ -87,13 +89,18 @@ def parallelize_deepseekv3( compiler_passes, dump_folder=job_config.job.dump_folder ) + options = GraphBuilderOptions( + dump_folder = job_config.job.dump_folder, + use_inductor_lite = is_using_inductor_lite(job_config), + ) + # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_passes=joint_custom_passes, - dump_folder=job_config.job.dump_folder, + 
options=options, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 0756dcec64..ed779a72b3 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -23,6 +23,11 @@ from torchtitan.tools.logging import logger +class GraphBuilderOptions: + dump_folder: str | None = None + use_inductor_lite: bool = False + + def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: # TODO: make the dump rank configurable if not dump_folder or torch.distributed.get_rank() != 0: @@ -88,7 +93,7 @@ def joint_graph_builder( fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, joint_custom_passes: Optional[List[Callable]] = None, - dump_folder: str | None = None, + options: GraphBuilderOptions = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -100,7 +105,7 @@ def joint_graph_builder( fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function joint_custom_passes: list of custom passes to run on the joint graph - dump_folder: Optional folder to dump the graph to + options: Optional configs for graph builder """ assert isinstance(model_args, tuple) for idx, arg in enumerate(model_args): @@ -110,7 +115,7 @@ def joint_graph_builder( ( joint_with_descriptors, tracing_context, - ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) + ) = export_joint(model, model_args, model_kwargs, dump_folder=options.dump_folder) # Optional validation if joint_custom_passes is not None: @@ -120,7 +125,7 @@ def joint_graph_builder( ) with tracing(tracing_context), torch._functorch.config.patch( - selective_decompose=True + selective_decompose=options.use_inductor_lite ): fn = aot_compile_joint_with_descriptors( joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler @@ -137,59 +142,6 @@ def wrapper_fn(args, kwargs): return wrapper_fn -def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None): - from torch._inductor import lite_mode_options - from torch._inductor.compile_fx import compile_fx_inner - - context = torch._guards.TracingContext.try_get() - - if not context or not context.fw_metadata: - logger.warn("No context or fw_metadata available") - static_input_idxs = () - else: - static_input_idxs = context.fw_metadata.static_input_indices - - inductor_config = lite_mode_options - if extra_config: - inductor_config.update(extra_config) - - def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): - with torch._inductor.config.patch(inductor_config): - compiled_fn = compile_fx_inner( - gm, - example_inputs, - static_input_idxs=static_input_idxs, - is_backward=False, - ) - return compiled_fn - - return fw_compiler - - -def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None): - from torch._inductor import lite_mode_options - from torch._inductor.compile_fx import compile_fx_inner - from torch._inductor.utils import count_tangents - - inductor_config = lite_mode_options - if extra_config: - inductor_config.update(extra_config) - - def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): - fixed = count_tangents(gm) - - with torch._inductor.config.patch(inductor_config): - compiled_fn = compile_fx_inner( - gm, - example_inputs, - static_input_idxs=list(range(fixed)), - 
is_backward=True, - ) - return compiled_fn - - return bw_compiler - - class CompiledModule(torch.nn.Module): def __init__( self, @@ -403,3 +355,8 @@ def get_joint_custom_passes_from_config( ) return joint_custom_passes + + +def is_using_inductor_lite(job_config: JobConfig) -> bool: + pass_names = getattr(job_config.compile, "passes", []) + return "inductor_lite" in pass_names diff --git a/torchtitan/experiments/compiler_toolkit/inductor_lite.py b/torchtitan/experiments/compiler_toolkit/inductor_lite.py new file mode 100644 index 0000000000..4344093943 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/inductor_lite.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Inductor lite pass for the compiler toolkit. + +This module provides inductor lite pass that can be applied to graph modules +during compilation. +""" +from typing import Any, Callable, List, Optional + +import torch +from torchtitan.tools.logging import logger + + +def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None): + from torch._inductor import lite_mode_options + from torch._inductor.compile_fx import compile_fx_inner + + context = torch._guards.TracingContext.try_get() + + if not context or not context.fw_metadata: + logger.warn("No context or fw_metadata available") + static_input_idxs = () + else: + static_input_idxs = context.fw_metadata.static_input_indices + + inductor_config = lite_mode_options + if extra_config: + inductor_config.update(extra_config) + + def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): + with torch._inductor.config.patch(inductor_config): + compiled_fn = compile_fx_inner( + gm, + example_inputs, + static_input_idxs=static_input_idxs, + is_backward=False, + ) + return compiled_fn + + return fw_compiler + + +def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None): + from torch._inductor import lite_mode_options + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.utils import count_tangents + + inductor_config = lite_mode_options + if extra_config: + inductor_config.update(extra_config) + + def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple): + fixed = count_tangents(gm) + + with torch._inductor.config.patch(inductor_config): + compiled_fn = compile_fx_inner( + gm, + example_inputs, + static_input_idxs=list(range(fixed)), + is_backward=True, + ) + return compiled_fn + + return bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index e746c24228..93e3d786b8 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -24,6 +24,8 @@ get_joint_custom_passes_from_config, joint_graph_builder, make_compiler_with_passes, + GraphBuilderOptions, + is_using_inductor_lite, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, @@ -74,13 +76,18 @@ def parallelize_llama( compiler_passes, dump_folder=job_config.job.dump_folder ) + options = GraphBuilderOptions( + dump_folder = job_config.job.dump_folder, + use_inductor_lite = is_using_inductor_lite(job_config), + ) + # Create custom joint_graph_builder with llama-specific compilers llama_joint_graph_builder = functools.partial( 
joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_passes=joint_custom_passes, - dump_folder=job_config.job.dump_folder, + options=options, ) # TODO: CompiledModule should take sample input as well, so that we can From 507e823fb6deccffb0bf363c844071c1a42be929 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 14:48:33 -0800 Subject: [PATCH 04/15] refactor --- .../deepseek_v3/parallelize.py | 8 +-- .../compiler_toolkit/graph_utils.py | 49 ++++++++++++++++++- .../compiler_toolkit/inductor_lite.py | 2 +- .../compiler_toolkit/llama3/parallelize.py | 8 +-- .../experiments/compiler_toolkit/passes.py | 29 +++++++++++ 5 files changed, 85 insertions(+), 11 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 704499f231..a6b3ee8915 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -23,10 +23,10 @@ CompiledModule, get_compiler_passes_from_config, get_joint_custom_passes_from_config, - joint_graph_builder, - make_compiler_with_passes, GraphBuilderOptions, is_using_inductor_lite, + joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import ( @@ -90,8 +90,8 @@ def parallelize_deepseekv3( ) options = GraphBuilderOptions( - dump_folder = job_config.job.dump_folder, - use_inductor_lite = is_using_inductor_lite(job_config), + dump_folder=job_config.job.dump_folder, + use_inductor_lite=is_using_inductor_lite(job_config), ) # Create custom joint_graph_builder with deepseekv3-specific compilers diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index ed779a72b3..cb85dcea88 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -10,6 +10,7 @@ from typing import Any, Callable, List, Optional import torch +from torch._dynamo.exc import Unsupported from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, @@ -40,6 +41,10 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No ) +def move_value_to_end(lst: list[Any], value: Any) -> None: + return [x for x in lst if x != value] + [x for x in lst if x == value] + + def export_joint( model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: @@ -224,6 +229,7 @@ def compiler( example_inputs, passes: List[Callable] = None, dump_folder: str | None = None, + is_forward: bool = True, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -240,6 +246,20 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES + if ( + len(passes) > 0 + and (last_pass := passes[-1].__name__) + and (last_pass == "cudagraph" or last_pass == "inductor_lite") + ): + # cudagraph pass or inductor lite pass is always the last pass if it is applied + # these two passes behaves differently for forward and backwawrd. so we explicitly + # pass the info. For example, different methods are used to identify static input + # indices. 
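+        # (e.g. the inductor lite forward compiler takes static_input_indices from
+        # fw_metadata, while its backward compiler derives them via count_tangents).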
+ last_pass = functools.partial(last_pass, is_forward=is_forward) + + # keep the function name for debug log + passes[-1] = functools.wraps(last_pass)(last_pass) + logger.debug(f"{name} before compiler:") logger.debug( gm.print_readable(print_output=False, include_stride=True, include_device=True) @@ -273,12 +293,22 @@ def make_compiler_with_passes( def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "fwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=True, ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "bwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=False, ) return fw_compiler, bw_compiler @@ -297,6 +327,21 @@ def get_compiler_passes_from_config(job_config: JobConfig): from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES pass_names = getattr(job_config.compile, "passes", []) + + if "cudagraph" in pass_names and "inductor_lite" in pass_names: + raise Unsupported("cudagraph and inductor_lite cannot be applied together.") + elif "cudagrpah" in pass_names: + # cudagraph pass has to be the last pass + move_value_to_end(pass_names, "cudagraph") + elif "inductor_lite" in pass_names: + # inductor lite supports regional_inductor by default. They share the same + # user-facing frontend API (i.e., the context manager), uses different + # backend implementations, and achieves the same compilation result. + pass_names.remove("regional_inductor") + + # inductor lite pass has to be the last pass + move_value_to_end(pass_names, "inductor_lite") + compiler_passes = [] for pass_name in pass_names: diff --git a/torchtitan/experiments/compiler_toolkit/inductor_lite.py b/torchtitan/experiments/compiler_toolkit/inductor_lite.py index 4344093943..a25c01b08f 100644 --- a/torchtitan/experiments/compiler_toolkit/inductor_lite.py +++ b/torchtitan/experiments/compiler_toolkit/inductor_lite.py @@ -10,7 +10,7 @@ This module provides inductor lite pass that can be applied to graph modules during compilation. 
""" -from typing import Any, Callable, List, Optional +from typing import Optional import torch from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 93e3d786b8..22a971de26 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -22,10 +22,10 @@ CompiledModule, get_compiler_passes_from_config, get_joint_custom_passes_from_config, - joint_graph_builder, - make_compiler_with_passes, GraphBuilderOptions, is_using_inductor_lite, + joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, @@ -77,8 +77,8 @@ def parallelize_llama( ) options = GraphBuilderOptions( - dump_folder = job_config.job.dump_folder, - use_inductor_lite = is_using_inductor_lite(job_config), + dump_folder=job_config.job.dump_folder, + use_inductor_lite=is_using_inductor_lite(job_config), ) # Create custom joint_graph_builder with llama-specific compilers diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index c0cec614a9..e6259533a5 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -11,9 +11,15 @@ during compilation. Passes can be selected and configured via job config. """ +from typing import Callable + import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor +from torchtitan.experiments.compiler_toolkit.inductor_lite import ( + get_inductor_lite_bw_compiler, + get_inductor_lite_fw_compiler, +) from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( annotate_fsdp_all_gather, ) @@ -69,8 +75,31 @@ def fsdp_reshard_after_fwd_pass( return gm +def inductor_lite_pass( + gm: torch.fx.GraphModule, example_inputs, is_forward: bool +) -> Callable: + """ + Apply inductor lite mode. + + This pass takes a gm and generates a callable (not gm) using inductor. The lite + mode falls back for all ops except explicitly user-annotated ops under + regional compile. 
+ """ + # TODO: fix inductor size assertion for all_reduce + # https://github.com/pytorch/pytorch/issues/167430 + extra_inductor_config = {"size_asserts": False} + + if is_forward: + _compiler = get_inductor_lite_fw_compiler(extra_inductor_config) + else: + _compiler = get_inductor_lite_bw_compiler(extra_inductor_config) + + return _compiler(gm, example_inputs) + + # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "regional_inductor": regional_inductor_pass, + "inductor_lite": inductor_lite_pass, } From 234eda84695b4677cf8f6fde7a3f8a248aa7a2ee Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 14:50:50 -0800 Subject: [PATCH 05/15] lint --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 0bafa25fb2..5ed0c7a023 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -10,7 +10,6 @@ from typing import Any, Callable, List, Optional import torch -from torch._dynamo.exc import Unsupported from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, From 215804e73c877eb7cf135cb48d5cd6674dfeb369 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 15:19:46 -0800 Subject: [PATCH 06/15] refactor --- .../compiler_toolkit/graph_utils.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 5ed0c7a023..4b05c93d11 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import dataclasses import functools from pathlib import Path from typing import Any, Callable, List, Optional @@ -23,6 +24,7 @@ from torchtitan.tools.logging import logger +@dataclasses.dataclass(frozen=True) class GraphBuilderOptions: dump_folder: str | None = None use_inductor_lite: bool = False @@ -247,17 +249,20 @@ def compiler( if ( len(passes) > 0 - and (last_pass := passes[-1].__name__) - and (last_pass == "cudagraph" or last_pass == "inductor_lite") + and (last_pass_name := passes[-1].__name__) + and ( + last_pass_name == "cudagraph_pass" or last_pass_name == "inductor_lite_pass" + ) ): # cudagraph pass or inductor lite pass is always the last pass if it is applied # these two passes behaves differently for forward and backwawrd. so we explicitly # pass the info. For example, different methods are used to identify static input # indices. 
- last_pass = functools.partial(last_pass, is_forward=is_forward) + last_pass = passes[-1] + _last_pass = functools.partial(last_pass, is_forward=is_forward) # keep the function name for debug log - passes[-1] = functools.wraps(last_pass)(last_pass) + passes[-1] = functools.wraps(last_pass)(_last_pass) logger.debug(f"{name} before compiler:") logger.debug( @@ -274,11 +279,20 @@ def compiler( logger.info(f"Applying pass: {pass_name}") gm = pass_fn(gm, example_inputs) - logger.debug(f"{name} after compiler:") - logger.debug( - gm.print_readable(print_output=False, include_stride=True, include_device=True) - ) - _dump_gm(dump_folder, gm, f"{name}_after_compiler") + if ( + len(passes) > 0 + and (last_pass_name := passes[-1].__name__) + and (last_pass_name != "inductor_lite_pass") + ): + # inductor lite mode returns a CompiledFxGraph which does not support print_readable. + logger.debug(f"{name} after compiler:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") + return gm @@ -344,7 +358,8 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi # inductor lite supports regional_inductor by default. They share the same # user-facing frontend API (i.e., the context manager), uses different # backend implementations, and achieves the same compilation result. - pass_names.remove("regional_inductor") + if "regional_inductor" in pass_names: + pass_names.remove("regional_inductor") # inductor lite pass has to be the last pass move_value_to_end(pass_names, "inductor_lite") From f790c61ba2f471f07e360a5efb060d3d277cc54d Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 16:31:54 -0800 Subject: [PATCH 07/15] nit --- .../compiler_toolkit/graph_utils.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 4b05c93d11..3f119853e0 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -46,6 +46,14 @@ def move_value_to_end(lst: list[Any], value: Any) -> None: return [x for x in lst if x != value] + [x for x in lst if x == value] +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) + + def export_joint( model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: @@ -247,13 +255,7 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES - if ( - len(passes) > 0 - and (last_pass_name := passes[-1].__name__) - and ( - last_pass_name == "cudagraph_pass" or last_pass_name == "inductor_lite_pass" - ) - ): + if end_with_pass(passes, ["cudagraph_pass", "inductor_lite_pass"]): # cudagraph pass or inductor lite pass is always the last pass if it is applied # these two passes behaves differently for forward and backwawrd. so we explicitly # pass the info. 
For example, different methods are used to identify static input @@ -279,11 +281,7 @@ def compiler( logger.info(f"Applying pass: {pass_name}") gm = pass_fn(gm, example_inputs) - if ( - len(passes) > 0 - and (last_pass_name := passes[-1].__name__) - and (last_pass_name != "inductor_lite_pass") - ): + if not end_with_pass(passes, ["inductor_lite_pass"]): # inductor lite mode returns a CompiledFxGraph which does not support print_readable. logger.debug(f"{name} after compiler:") logger.debug( From e887be7f14715b8509fed410305cf435e2f811ae Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 16:35:01 -0800 Subject: [PATCH 08/15] add test --- .../compiler_toolkit/tests/integration_tests.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index b0155a9f2a..e8f06b031b 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -133,6 +133,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "deepseekv3_fsdp_tp_ep_flexattention", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering", + ], + ], + "llama3 FSDP+TP+inductor_lite", + "llama3_fsdp_tp_inductor_lite", + ngpu=4, + ), ] return integration_tests_flavors From 0b4314046de94cf9d08b007904c125dcbc510646 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 16:37:17 -0800 Subject: [PATCH 09/15] nit --- .../tests/integration_tests.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index e8f06b031b..6c0ddc845a 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -101,6 +101,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_manualbucketing_regional_inductor", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes inductor_lite", + ], + ], + "llama3 FSDP+TP+inductor_lite", + "llama3_fsdp_tp_inductor_lite", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ @@ -133,20 +147,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "deepseekv3_fsdp_tp_ep_flexattention", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--model.name compiler_toolkit.llama3", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", - "--compile.passes autobucketing_reordering", - ], - ], - "llama3 FSDP+TP+inductor_lite", - "llama3_fsdp_tp_inductor_lite", - ngpu=4, - ), ] return integration_tests_flavors From 26b4e1dc13dd537ab5ea2074a3351c292af4cf94 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 16:40:30 -0800 Subject: [PATCH 10/15] nit --- 
torchtitan/experiments/compiler_toolkit/graph_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 3f119853e0..7d80d68b00 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -354,8 +354,8 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi move_value_to_end(pass_names, "cudagraph") elif "inductor_lite" in pass_names: # inductor lite supports regional_inductor by default. They share the same - # user-facing frontend API (i.e., the context manager), uses different - # backend implementations, and achieves the same compilation result. + # user-facing frontend API (i.e., the context manager), use different + # backend implementations, and achieve the same compilation result. if "regional_inductor" in pass_names: pass_names.remove("regional_inductor") @@ -369,7 +369,6 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi raise ValueError( "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" ) - compiler_passes = [] for pass_name in pass_names: From bcf079e56b60860812b92511966ca82e2f8c52fb Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 21:59:48 -0800 Subject: [PATCH 11/15] more tests and fixes --- .../experiments/compiler_toolkit/README.md | 11 ++++ .../compiler_toolkit/common_utils.py | 9 ++++ .../compiler_toolkit/graph_utils.py | 51 ++++++++----------- .../tests/integration_tests.py | 30 +++++++++++ 4 files changed, 70 insertions(+), 31 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index c223d1e658..fb588803bb 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -22,6 +22,12 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn ``` +**SimpleFSDP + TP + EP + Inductor Lite** +```shell +NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite +``` + + ## llama3 **SimpleFSDP + TP** @@ -39,6 +45,11 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` +**SimpleFSDP + TP + transformer-block-bucketing** +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 
--parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite +``` + **SimpleFSDP + TP + FlexAttention** ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..997af9a2c4 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from contextlib import contextmanager +from typing import Callable import torch from torch.distributed.tensor import DTensor, Replicate @@ -53,3 +54,11 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) + + +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 7d80d68b00..dc2ffc7196 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -21,6 +21,7 @@ from torch.distributed.tensor import DTensor from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass from torchtitan.tools.logging import logger @@ -42,18 +43,6 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No ) -def move_value_to_end(lst: list[Any], value: Any) -> None: - return [x for x in lst if x != value] + [x for x in lst if x == value] - - -def end_with_pass(passes: list[Callable], names: list[str]) -> bool: - return ( - len(passes) > 0 - and (last_pass_name := getattr(passes[-1], "__name__", None)) - and (last_pass_name in names) - ) - - def export_joint( model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: @@ -255,10 +244,10 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES - if end_with_pass(passes, ["cudagraph_pass", "inductor_lite_pass"]): - # cudagraph pass or inductor lite pass is always the last pass if it is applied - # these two passes behaves differently for forward and backwawrd. so we explicitly - # pass the info. For example, different methods are used to identify static input + if end_with_pass(passes, ["inductor_lite_pass"]): + # inductor lite pass is always the last pass if it is applied since it + # behaves differently for forward and backwawrd. so we explicitly pass + # the info. For example, different methods are used to identify static input # indices. last_pass = passes[-1] _last_pass = functools.partial(last_pass, is_forward=is_forward) @@ -330,6 +319,20 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return fw_compiler, bw_compiler +def validate_pass_names(pass_names: list[str]) -> None: + if "inductor_lite" in pass_names: + # inductor lite supports regional_inductor by default. 
They share the same + # user-facing frontend API (i.e., the context manager), use different + # backend implementations, and achieve the same compilation result. + assert "regional_inductor" not in pass_names, ( + "inductor_lite uses regional_inductor by default. please use one " + "pass at a time." + ) + assert ( + pass_names[-1] == "inductor_lite" + ), "inductor_lite has to be the last pass to apply" + + def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ Extract and validate compiler passes from job config. @@ -346,21 +349,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi ) pass_names = getattr(job_config.compile, "passes", []) - - if "cudagraph" in pass_names and "inductor_lite" in pass_names: - raise ValueError("Cannot apply cudagraph and inductor_lite at the same time!") - elif "cudagrpah" in pass_names: - # cudagraph pass has to be the last pass - move_value_to_end(pass_names, "cudagraph") - elif "inductor_lite" in pass_names: - # inductor lite supports regional_inductor by default. They share the same - # user-facing frontend API (i.e., the context manager), use different - # backend implementations, and achieve the same compilation result. - if "regional_inductor" in pass_names: - pass_names.remove("regional_inductor") - - # inductor lite pass has to be the last pass - move_value_to_end(pass_names, "inductor_lite") + validate_pass_names(pass_names) if ( "autobucketing_reordering" in pass_names diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index 6c0ddc845a..c2f995198f 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -131,6 +131,22 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "deepseekv3_fsdp_tp_ep", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.deepseek_v3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--activation_checkpoint.mode none", + "--compile.passes inductor_lite", + ], + ], + "deepseek_v3 FSDP+TP+EP+Inductor_lite", + "deepseekv3_fsdp_tp_ep_inductor_lite", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -147,6 +163,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "deepseekv3_fsdp_tp_ep_flexattention", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes inductor_lite", + ], + ], + "llama3 FSDP+TP+inductor_lite", + "llama3_fsdp_tp_inductor_lite", + ngpu=4, + ), ] return integration_tests_flavors From cdccffaeb566f3d21a5649135f4565262cde1077 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Nov 2025 22:30:11 -0800 Subject: [PATCH 12/15] nit --- torchtitan/experiments/compiler_toolkit/README.md | 2 +- .../experiments/compiler_toolkit/tests/integration_tests.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index fb588803bb..beeb38f272 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ 
b/torchtitan/experiments/compiler_toolkit/README.md @@ -45,7 +45,7 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` -**SimpleFSDP + TP + transformer-block-bucketing** +**SimpleFSDP + TP + transformer-block-bucketing + inductor lite** ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite ``` diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index c2f995198f..8d0844e94c 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -140,6 +140,7 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--parallelism.expert_parallel_degree 4", "--parallelism.expert_tensor_parallel_degree 1", "--activation_checkpoint.mode none", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", "--compile.passes inductor_lite", ], ], From 7052b243b3b23df68ec20684f6061218ae26e1ce Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 19 Nov 2025 10:47:40 -0800 Subject: [PATCH 13/15] patch fake_tensor_cache_enabled for a known issue --- .../experiments/compiler_toolkit/graph_utils.py | 15 ++++++++------- torchtitan/experiments/compiler_toolkit/passes.py | 11 +++++++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index dc2ffc7196..a7f7d6aea1 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -332,6 +332,14 @@ def validate_pass_names(pass_names: list[str]) -> None: pass_names[-1] == "inductor_lite" ), "inductor_lite has to be the last pass to apply" + if ( + "autobucketing_reordering" in pass_names + and "transformer_block_bucketing" in pass_names + ): + raise ValueError( + "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" + ) + def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ @@ -351,13 +359,6 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi pass_names = getattr(job_config.compile, "passes", []) validate_pass_names(pass_names) - if ( - "autobucketing_reordering" in pass_names - and "transformer_block_bucketing" in pass_names - ): - raise ValueError( - "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" 
- ) compiler_passes = [] for pass_name in pass_names: diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 8fb7cc4847..7e9653cce9 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -108,8 +108,15 @@ def inductor_lite_pass( else: _compiler = get_inductor_lite_bw_compiler(extra_inductor_config) - return _compiler(gm, example_inputs) - + with ( + # TODO Investigate error on MOE model with use_grouped_mm=False. + # For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61 + torch._dynamo.config.patch(fake_tensor_cache_enabled=False), + torch.fx.traceback.preserve_node_meta(), + ): + compiled_fn = _compiler(gm, example_inputs) + + return compiled_fn # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { From d5aa70562d78878e14373ca902c6a51ece92ae36 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 19 Nov 2025 13:10:01 -0800 Subject: [PATCH 14/15] lint --- torchtitan/experiments/compiler_toolkit/passes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 7e9653cce9..0a89ec3c56 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -118,6 +118,7 @@ def inductor_lite_pass( return compiled_fn + # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, From 249106ea4b796cbda08c17dcf4b6c781d8e58381 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 19 Nov 2025 21:32:14 -0800 Subject: [PATCH 15/15] nit --- torchtitan/experiments/compiler_toolkit/README.md | 4 ++-- .../experiments/compiler_toolkit/graph_utils.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 086c6f05e8..21c855d80e 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -24,7 +24,7 @@ NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to **SimpleFSDP + TP + EP + Inductor Lite** ```shell -NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite +NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite ``` @@ -47,7 +47,7 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to **SimpleFSDP + TP + transformer-block-bucketing + inductor lite** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 
--parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite ``` **SimpleFSDP + TP + FlexAttention** diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 34a571e7bb..77076ee79e 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -261,16 +261,16 @@ def compiler( ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") - if end_with_pass(passes, ["cudagraph_pass"]): - # cudagraph pass is always the last pass if it is applied - cg_pass = passes[-1] + if end_with_pass(passes, ["cudagraph_pass", "inductor_lite_pass"]): + # cudagraph pass or inductor lite pass is always the last pass if it is applied + last_pass = passes[-1] - # to identify static input indices, cudagraph passes behaves differently for - # forward and backward pass. so we explicitly pass the info. - _cg_pass = functools.partial(cg_pass, is_forward=is_forward) + # these two passes behave differently for forward and backward pass to identify + # static input indices. so we explicitly pass the info. + _last_pass = functools.partial(last_pass, is_forward=is_forward) # keep the function name for debug log - passes[-1] = functools.wraps(cg_pass)(_cg_pass) + passes[-1] = functools.wraps(last_pass)(_last_pass) for pass_fn in passes: pass_name = ( @@ -375,7 +375,6 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi pass_names = getattr(job_config.compile, "passes", []) validate_pass_names(pass_names) - compiler_passes = [] for pass_name in pass_names: