diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md
index 620911ce60..21c855d80e 100644
--- a/torchtitan/experiments/compiler_toolkit/README.md
+++ b/torchtitan/experiments/compiler_toolkit/README.md
@@ -22,6 +22,12 @@ NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to
 NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn
 ```
 
+**SimpleFSDP + TP + EP + Inductor Lite**
+```shell
+NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
+```
+
+
 ## llama3
 
 **SimpleFSDP + TP**
@@ -39,6 +45,11 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to
 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
 ```
 
+**SimpleFSDP + TP + Inductor Lite**
+```shell
+NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
+```
+
 **SimpleFSDP + TP + FlexAttention**
 ```shell
 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 011bbe402a..e961eaab21 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -23,6 +23,8 @@
     CompiledModule,
     get_compiler_passes_from_config,
     get_joint_custom_passes_from_config,
+    GraphBuilderOptions,
+    is_using_inductor_lite,
     joint_graph_builder,
     make_compiler_with_passes,
 )
@@ -87,13 +89,18 @@ def parallelize_deepseekv3(
         compiler_passes, dump_folder=job_config.job.dump_folder
     )
 
+    options = GraphBuilderOptions(
+        dump_folder=job_config.job.dump_folder,
+        use_inductor_lite=is_using_inductor_lite(job_config),
+    )
+
     # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_passes=joint_custom_passes,
-        dump_folder=job_config.job.dump_folder,
+        options=options,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
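For orientation, the wiring added to both `parallelize_*` functions (here and in `llama3/parallelize.py` below) reduces to the sketch that follows. The helper name `build_joint_graph_builder` is hypothetical; `fw_compiler`, `bw_compiler`, and `joint_custom_passes` are assumed to have been created earlier in the function, as in the surrounding code.

```python
import functools

from torchtitan.experiments.compiler_toolkit.graph_utils import (
    GraphBuilderOptions,
    is_using_inductor_lite,
    joint_graph_builder,
)


def build_joint_graph_builder(job_config, fw_compiler, bw_compiler, joint_custom_passes):
    # Bundle the graph-builder knobs into one frozen options object; use_inductor_lite
    # is derived from whether "inductor_lite" appears in --compile.passes.
    options = GraphBuilderOptions(
        dump_folder=job_config.job.dump_folder,
        use_inductor_lite=is_using_inductor_lite(job_config),
    )
    return functools.partial(
        joint_graph_builder,
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        joint_custom_passes=joint_custom_passes,
        options=options,
    )
```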
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index e097579cc0..77076ee79e 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+import dataclasses
 import functools
 from pathlib import Path
 from typing import Any, Callable, List, Optional
@@ -24,6 +25,12 @@
 from torchtitan.tools.logging import logger
 
 
+@dataclasses.dataclass(frozen=True)
+class GraphBuilderOptions:
+    dump_folder: str | None = None
+    use_inductor_lite: bool = False
+
+
 def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
     # TODO: make the dump rank configurable
     if not dump_folder or torch.distributed.get_rank() != 0:
@@ -89,7 +96,7 @@ def joint_graph_builder(
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
     joint_custom_passes: Optional[List[Callable]] = None,
-    dump_folder: str | None = None,
+    options: GraphBuilderOptions | None = None,
 ):
     """
     Build a joint forward-backward graph for the model with optional custom compilers.
@@ -101,7 +108,7 @@ def joint_graph_builder(
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
         joint_custom_passes: list of custom passes to run on the joint graph
-        dump_folder: Optional folder to dump the graph to
+        options: Optional configuration for the graph builder
     """
     assert isinstance(model_args, tuple)
     for idx, arg in enumerate(model_args):
@@ -111,7 +118,7 @@ def joint_graph_builder(
     (
         joint_with_descriptors,
         tracing_context,
-    ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
+    ) = export_joint(model, model_args, model_kwargs, dump_folder=options.dump_folder)
 
     # run custom passes on joint-graph before partitioner
     if joint_custom_passes is not None:
@@ -120,7 +127,9 @@
                 joint_with_descriptors.graph_module
             )
 
-    with tracing(tracing_context):
+    with tracing(tracing_context), torch._functorch.config.patch(
+        selective_decompose=options.use_inductor_lite
+    ):
         fn = aot_compile_joint_with_descriptors(
             joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
         )
@@ -235,22 +244,33 @@ def compiler(
     if passes is None:
         passes = DEFAULT_COMPILER_PASSES
 
+    if end_with_pass(passes, ["inductor_lite_pass"]):
+        # inductor lite pass is always the last pass if it is applied, since it
+        # behaves differently for the forward and backward passes, so we explicitly
+        # pass the info. For example, different methods are used to identify static
+        # input indices.
+        last_pass = passes[-1]
+        _last_pass = functools.partial(last_pass, is_forward=is_forward)
+
+        # keep the function name for debug log
+        passes[-1] = functools.wraps(last_pass)(_last_pass)
+
     logger.debug(f"{name} before compiler:")
     logger.debug(
         gm.print_readable(print_output=False, include_stride=True, include_device=True)
     )
     _dump_gm(dump_folder, gm, f"{name}_before_compiler")
 
-    if end_with_pass(passes, ["cudagraph_pass"]):
-        # cudagraph pass is always the last pass if it is applied
-        cg_pass = passes[-1]
+    if end_with_pass(passes, ["cudagraph_pass", "inductor_lite_pass"]):
+        # cudagraph pass or inductor lite pass is always the last pass if it is applied
+        last_pass = passes[-1]
 
-        # to identify static input indices, cudagraph passes behaves differently for
-        # forward and backward pass. so we explicitly pass the info.
-        _cg_pass = functools.partial(cg_pass, is_forward=is_forward)
+        # these two passes behave differently for the forward and backward passes
+        # (e.g., to identify static input indices), so we explicitly pass the info.
+        _last_pass = functools.partial(last_pass, is_forward=is_forward)
 
         # keep the function name for debug log
-        passes[-1] = functools.wraps(cg_pass)(_cg_pass)
+        passes[-1] = functools.wraps(last_pass)(_last_pass)
 
     for pass_fn in passes:
         pass_name = (
@@ -261,11 +281,16 @@
         logger.info(f"Applying pass: {pass_name}")
         gm = pass_fn(gm, example_inputs)
 
-    logger.debug(f"{name} after compiler:")
-    logger.debug(
-        gm.print_readable(print_output=False, include_stride=True, include_device=True)
-    )
-    _dump_gm(dump_folder, gm, f"{name}_after_compiler")
+    if not end_with_pass(passes, ["inductor_lite_pass"]):
+        # inductor lite mode returns a CompiledFxGraph which does not support print_readable.
+        logger.debug(f"{name} after compiler:")
+        logger.debug(
+            gm.print_readable(
+                print_output=False, include_stride=True, include_device=True
+            )
+        )
+        _dump_gm(dump_folder, gm, f"{name}_after_compiler")
+
 
     return gm
 
@@ -306,7 +331,20 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
 
 
 def validate_pass_names(pass_names: list[str]) -> None:
-    if "cudagraph" in pass_names:
+    if "inductor_lite" in pass_names and "cudagraph" in pass_names:
+        raise ValueError("Cannot apply inductor_lite and cudagraph at the same time!")
+    elif "inductor_lite" in pass_names:
+        # inductor lite supports regional_inductor by default. They share the same
+        # user-facing frontend API (i.e., the context manager), use different
+        # backend implementations, and achieve the same compilation result.
+        assert "regional_inductor" not in pass_names, (
+            "inductor_lite uses regional_inductor by default. Please use one "
+            "pass at a time."
+        )
+        assert (
+            pass_names[-1] == "inductor_lite"
+        ), "inductor_lite has to be the last pass to apply"
+    elif "cudagraph" in pass_names:
         assert (
             pass_names[-1] == "cudagraph"
         ), "cudagraph has to be the last pass to apply"
@@ -403,3 +441,8 @@ def get_joint_custom_passes_from_config(
         )
 
     return joint_custom_passes
+
+
+def is_using_inductor_lite(job_config: JobConfig) -> bool:
+    pass_names = getattr(job_config.compile, "passes", [])
+    return "inductor_lite" in pass_names
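The new branch in `validate_pass_names` above constrains how `inductor_lite` may be combined with the other passes. A minimal illustration (not part of the patch), assuming the function is imported from `graph_utils`:

```python
from torchtitan.experiments.compiler_toolkit.graph_utils import validate_pass_names

# inductor_lite must be the last pass, cannot be mixed with cudagraph, and
# already covers what regional_inductor does.
validate_pass_names(["transformer_block_bucketing", "inductor_lite"])  # OK
validate_pass_names(["inductor_lite", "cudagraph"])  # raises ValueError
validate_pass_names(["regional_inductor", "inductor_lite"])  # raises AssertionError
```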
diff --git a/torchtitan/experiments/compiler_toolkit/inductor_lite.py b/torchtitan/experiments/compiler_toolkit/inductor_lite.py
new file mode 100644
index 0000000000..a25c01b08f
--- /dev/null
+++ b/torchtitan/experiments/compiler_toolkit/inductor_lite.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Inductor lite pass for the compiler toolkit.
+
+This module provides the inductor lite compilers that are applied to graph
+modules during compilation.
+"""
+from typing import Optional
+
+import torch
+from torchtitan.tools.logging import logger
+
+
+def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    context = torch._guards.TracingContext.try_get()
+
+    if not context or not context.fw_metadata:
+        logger.warning("No context or fw_metadata available")
+        static_input_idxs = ()
+    else:
+        static_input_idxs = context.fw_metadata.static_input_indices
+
+    inductor_config = dict(lite_mode_options)  # copy; do not mutate the shared defaults
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=static_input_idxs,
+                is_backward=False,
+            )
+        return compiled_fn
+
+    return fw_compiler
+
+
+def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+    from torch._inductor.utils import count_tangents
+
+    inductor_config = dict(lite_mode_options)  # copy; do not mutate the shared defaults
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        fixed = count_tangents(gm)
+
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                is_backward=True,
+            )
+        return compiled_fn
+
+    return bw_compiler
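The two factories above are consumed by the `inductor_lite` pass in `passes.py` (further down in this diff). A sketch of their standalone use follows; `compile_with_inductor_lite` is a hypothetical wrapper, and `fw_gm`/`bw_gm` with their example inputs are assumed to come from the partitioned joint graph.

```python
from torchtitan.experiments.compiler_toolkit.inductor_lite import (
    get_inductor_lite_bw_compiler,
    get_inductor_lite_fw_compiler,
)


def compile_with_inductor_lite(fw_gm, fw_inputs, bw_gm, bw_inputs):
    # size_asserts is disabled to work around the all_reduce size-assertion
    # issue tracked in https://github.com/pytorch/pytorch/issues/167430.
    extra_config = {"size_asserts": False}

    # Each factory returns a compiler callable; calling it lowers the given
    # fx.GraphModule with inductor's lite mode and returns a compiled function
    # (not a GraphModule).
    compiled_fw = get_inductor_lite_fw_compiler(extra_config)(fw_gm, fw_inputs)
    compiled_bw = get_inductor_lite_bw_compiler(extra_config)(bw_gm, bw_inputs)
    return compiled_fw, compiled_bw
```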
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index 68fa7443f4..d56db32f4f 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -22,6 +22,8 @@
     CompiledModule,
     get_compiler_passes_from_config,
     get_joint_custom_passes_from_config,
+    GraphBuilderOptions,
+    is_using_inductor_lite,
     joint_graph_builder,
     make_compiler_with_passes,
 )
@@ -74,13 +76,18 @@ def parallelize_llama(
         compiler_passes, dump_folder=job_config.job.dump_folder
     )
 
+    options = GraphBuilderOptions(
+        dump_folder=job_config.job.dump_folder,
+        use_inductor_lite=is_using_inductor_lite(job_config),
+    )
+
     # Create custom joint_graph_builder with llama-specific compilers
     llama_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_passes=joint_custom_passes,
-        dump_folder=job_config.job.dump_folder,
+        options=options,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
""" -from typing import Any, Sequence +from typing import Any, Callable, Sequence import torch from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing @@ -21,6 +21,10 @@ CUDAGraphWrapper, get_static_input_indices, ) +from torchtitan.experiments.compiler_toolkit.inductor_lite import ( + get_inductor_lite_bw_compiler, + get_inductor_lite_fw_compiler, +) from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( annotate_fsdp_all_gather, ) @@ -106,10 +110,41 @@ def fsdp_reshard_after_fwd_pass( return gm +def inductor_lite_pass( + gm: torch.fx.GraphModule, example_inputs, is_forward: bool +) -> Callable: + """ + Apply inductor lite mode. + + This pass takes a gm and generates a callable (not gm) using inductor. The lite + mode falls back for all ops except explicitly user-annotated ops under + regional compile. + """ + # TODO: fix inductor size assertion for all_reduce + # https://github.com/pytorch/pytorch/issues/167430 + extra_inductor_config = {"size_asserts": False} + + if is_forward: + _compiler = get_inductor_lite_fw_compiler(extra_inductor_config) + else: + _compiler = get_inductor_lite_bw_compiler(extra_inductor_config) + + with ( + # TODO Investigate error on MOE model with use_grouped_mm=False. + # For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61 + torch._dynamo.config.patch(fake_tensor_cache_enabled=False), + torch.fx.traceback.preserve_node_meta(), + ): + compiled_fn = _compiler(gm, example_inputs) + + return compiled_fn + + # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "transformer_block_bucketing": transformer_block_bucketing_reordering_pass, "regional_inductor": regional_inductor_pass, + "inductor_lite": inductor_lite_pass, "cudagraph": cudagraph_pass, } diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index f01a1c4380..871e4fe387 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -130,6 +130,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_manualbucketing_regional_inductor", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes inductor_lite", + ], + ], + "llama3 FSDP+TP+inductor_lite", + "llama3_fsdp_tp_inductor_lite", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ @@ -146,6 +160,23 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "deepseekv3_fsdp_tp_ep", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.deepseek_v3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--activation_checkpoint.mode none", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes inductor_lite", + ], + ], + "deepseek_v3 FSDP+TP+EP+Inductor_lite", + "deepseekv3_fsdp_tp_ep_inductor_lite", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -162,6 +193,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: 
"deepseekv3_fsdp_tp_ep_flexattention", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes inductor_lite", + ], + ], + "llama3 FSDP+TP+inductor_lite", + "llama3_fsdp_tp_inductor_lite", + ngpu=4, + ), ] return integration_tests_flavors