11 changes: 11 additions & 0 deletions torchtitan/experiments/compiler_toolkit/README.md
@@ -22,6 +22,12 @@
NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn
```

**SimpleFSDP + TP + EP + Inductor Lite**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
```


## llama3

**SimpleFSDP + TP**
@@ -39,6 +45,11 @@
NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
```

**SimpleFSDP + TP + Inductor Lite**
```shell
NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes inductor_lite
```

**SimpleFSDP + TP + FlexAttention**
```shell
NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
@@ -23,6 +23,8 @@
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
GraphBuilderOptions,
is_using_inductor_lite,
joint_graph_builder,
make_compiler_with_passes,
)
@@ -87,13 +89,18 @@ def parallelize_deepseekv3(
compiler_passes, dump_folder=job_config.job.dump_folder
)

options = GraphBuilderOptions(
dump_folder=job_config.job.dump_folder,
use_inductor_lite=is_using_inductor_lite(job_config),
)

# Create custom joint_graph_builder with deepseekv3-specific compilers
deepseekv3_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
options=options,
)

# TODO: CompiledModule should take sample input as well, so that we can
77 changes: 60 additions & 17 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import dataclasses
import functools
from pathlib import Path
from typing import Any, Callable, List, Optional
@@ -24,6 +25,12 @@
from torchtitan.tools.logging import logger


@dataclasses.dataclass(frozen=True)
class GraphBuilderOptions:
dump_folder: str | None = None
use_inductor_lite: bool = False


def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
# TODO: make the dump rank configurable
if not dump_folder or torch.distributed.get_rank() != 0:
@@ -89,7 +96,7 @@ def joint_graph_builder(
fw_compiler: Optional[Callable] = None,
bw_compiler: Optional[Callable] = None,
joint_custom_passes: Optional[List[Callable]] = None,
dump_folder: str | None = None,
options: GraphBuilderOptions | None = None,
):
"""
Build a joint forward-backward graph for the model with optional custom compilers.
@@ -101,7 +108,7 @@
fw_compiler: Optional custom forward compiler function
bw_compiler: Optional custom backward compiler function
joint_custom_passes: list of custom passes to run on the joint graph
dump_folder: Optional folder to dump the graph to
options: Optional GraphBuilderOptions (graph dump folder, inductor lite mode)
"""
assert isinstance(model_args, tuple)
for idx, arg in enumerate(model_args):
@@ -111,7 +118,7 @@
(
joint_with_descriptors,
tracing_context,
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
) = export_joint(model, model_args, model_kwargs, dump_folder=options.dump_folder)

# run custom passes on joint-graph before partitioner
if joint_custom_passes is not None:
@@ -120,7 +127,9 @@
joint_with_descriptors.graph_module
)

with tracing(tracing_context):
with tracing(tracing_context), torch._functorch.config.patch(
selective_decompose=options.use_inductor_lite
):
fn = aot_compile_joint_with_descriptors(
joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
)
@@ -235,22 +244,33 @@ def compiler(
if passes is None:
passes = DEFAULT_COMPILER_PASSES

if end_with_pass(passes, ["inductor_lite_pass"]):
# inductor lite pass is always the last pass if it is applied, since it
# behaves differently for forward and backward. So we explicitly pass
# the info. For example, different methods are used to identify static input
# indices.
last_pass = passes[-1]
_last_pass = functools.partial(last_pass, is_forward=is_forward)

# keep the function name for debug log
passes[-1] = functools.wraps(last_pass)(_last_pass)

logger.debug(f"{name} before compiler:")
logger.debug(
gm.print_readable(print_output=False, include_stride=True, include_device=True)
)
_dump_gm(dump_folder, gm, f"{name}_before_compiler")

if end_with_pass(passes, ["cudagraph_pass"]):
# cudagraph pass is always the last pass if it is applied
cg_pass = passes[-1]
if end_with_pass(passes, ["cudagraph_pass", "inductor_lite_pass"]):
# cudagraph pass or inductor lite pass is always the last pass if it is applied
last_pass = passes[-1]

# to identify static input indices, cudagraph passes behaves differently for
# forward and backward pass. so we explicitly pass the info.
_cg_pass = functools.partial(cg_pass, is_forward=is_forward)
# These two passes behave differently between the forward and backward pass when
# identifying static input indices, so we explicitly pass that info.
_last_pass = functools.partial(last_pass, is_forward=is_forward)

# keep the function name for debug log
passes[-1] = functools.wraps(cg_pass)(_cg_pass)
passes[-1] = functools.wraps(last_pass)(_last_pass)

for pass_fn in passes:
pass_name = (
@@ -261,11 +281,16 @@
logger.info(f"Applying pass: {pass_name}")
gm = pass_fn(gm, example_inputs)

logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(print_output=False, include_stride=True, include_device=True)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
if not end_with_pass(passes, ["inductor_lite_pass"]):
# inductor lite mode returns a CompiledFxGraph which does not support print_readable.
logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(
print_output=False, include_stride=True, include_device=True
)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")

return gm


@@ -306,7 +331,20 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:


def validate_pass_names(pass_names: list[str]) -> None:
if "cudagraph" in pass_names:
if "inductor_lite" in pass_names and "cudagraph" in pass_names:
raise ValueError("Cannot apply inductor_lite and cudagraph at the same time!")
elif "inductor_lite" in pass_names:
# inductor lite already includes regional_inductor behavior. They share the same
# user-facing frontend API (i.e., the annotation context manager), use different
# backend implementations, and achieve the same compilation result.
assert "regional_inductor" not in pass_names, (
"inductor_lite already covers regional_inductor. Please use one "
"pass at a time."
)
assert (
pass_names[-1] == "inductor_lite"
), "inductor_lite has to be the last pass to apply"
elif "cudagraph" in pass_names:
assert (
pass_names[-1] == "cudagraph"
), "cudagraph has to be the last pass to apply"
@@ -403,3 +441,8 @@ def get_joint_custom_passes_from_config(
)

return joint_custom_passes


def is_using_inductor_lite(job_config: JobConfig) -> bool:
pass_names = getattr(job_config.compile, "passes", [])
return "inductor_lite" in pass_names
69 changes: 69 additions & 0 deletions torchtitan/experiments/compiler_toolkit/inductor_lite.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Inductor lite pass for the compiler toolkit.

This module provides the forward and backward inductor compilers used by the
inductor lite pass, which is applied to graph modules during compilation.
"""
from typing import Optional

import torch
from torchtitan.tools.logging import logger


def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None):
from torch._inductor import lite_mode_options
from torch._inductor.compile_fx import compile_fx_inner

context = torch._guards.TracingContext.try_get()

if not context or not context.fw_metadata:
logger.warning("No tracing context or fw_metadata available")
static_input_idxs = ()
else:
static_input_idxs = context.fw_metadata.static_input_indices

inductor_config = lite_mode_options
if extra_config:
inductor_config.update(extra_config)

def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
with torch._inductor.config.patch(inductor_config):
compiled_fn = compile_fx_inner(
gm,
example_inputs,
static_input_idxs=static_input_idxs,
is_backward=False,
)
return compiled_fn

return fw_compiler


def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None):
from torch._inductor import lite_mode_options
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.utils import count_tangents

inductor_config = lite_mode_options
if extra_config:
inductor_config.update(extra_config)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
fixed = count_tangents(gm)

with torch._inductor.config.patch(inductor_config):
compiled_fn = compile_fx_inner(
gm,
example_inputs,
static_input_idxs=list(range(fixed)),
is_backward=True,
)
return compiled_fn

return bw_compiler
@@ -22,6 +22,8 @@
CompiledModule,
get_compiler_passes_from_config,
get_joint_custom_passes_from_config,
GraphBuilderOptions,
is_using_inductor_lite,
joint_graph_builder,
make_compiler_with_passes,
)
@@ -74,13 +76,18 @@ def parallelize_llama(
compiler_passes, dump_folder=job_config.job.dump_folder
)

options = GraphBuilderOptions(
dump_folder=job_config.job.dump_folder,
use_inductor_lite=is_using_inductor_lite(job_config),
)

# Create custom joint_graph_builder with llama-specific compilers
llama_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
options=options,
)

# TODO: CompiledModule should take sample input as well, so that we can
37 changes: 36 additions & 1 deletion torchtitan/experiments/compiler_toolkit/passes.py
@@ -11,7 +11,7 @@
during compilation. Passes can be selected and configured via job config.
"""

from typing import Any, Sequence
from typing import Any, Callable, Sequence

import torch
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
@@ -21,6 +21,10 @@
CUDAGraphWrapper,
get_static_input_indices,
)
from torchtitan.experiments.compiler_toolkit.inductor_lite import (
get_inductor_lite_bw_compiler,
get_inductor_lite_fw_compiler,
)
from torchtitan.experiments.simple_fsdp.reshard_after_forward import (
annotate_fsdp_all_gather,
)
@@ -106,10 +110,41 @@ def fsdp_reshard_after_fwd_pass(
return gm


def inductor_lite_pass(
gm: torch.fx.GraphModule, example_inputs, is_forward: bool
) -> Callable:
"""
Apply inductor lite mode.

This pass takes a GraphModule and uses inductor to generate a callable (not a
GraphModule). Lite mode falls back to eager for every op except those the user
explicitly annotated for regional compile.
"""
# TODO: fix inductor size assertion for all_reduce
# https://github.com/pytorch/pytorch/issues/167430
extra_inductor_config = {"size_asserts": False}

if is_forward:
_compiler = get_inductor_lite_fw_compiler(extra_inductor_config)
else:
_compiler = get_inductor_lite_bw_compiler(extra_inductor_config)

with (
# TODO Investigate error on MOE model with use_grouped_mm=False.
# For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61
torch._dynamo.config.patch(fake_tensor_cache_enabled=False),
torch.fx.traceback.preserve_node_meta(),
):
compiled_fn = _compiler(gm, example_inputs)

return compiled_fn


# Registry mapping pass names to pass functions
AVAILABLE_COMPILER_PASSES = {
"autobucketing_reordering": autobucketing_reordering_pass,
"transformer_block_bucketing": transformer_block_bucketing_reordering_pass,
"regional_inductor": regional_inductor_pass,
"inductor_lite": inductor_lite_pass,
"cudagraph": cudagraph_pass,
}
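
For context on the "explicitly user-annotated ops under regional compile" mentioned in the `inductor_lite_pass` docstring above, here is a hypothetical sketch of how a model region could be marked so the lite pass lowers it with inductor while everything else stays on the eager fallback. The `torch.fx.traceback.annotate` context manager and the `"compile_with_inductor"` key are assumptions based on the regional-inductor frontend, not something introduced by this diff.

```python
# Hypothetical sketch: annotating a region for regional / lite compilation.
# Assumption: fx_traceback.annotate is the context manager the regional-inductor
# frontend keys on, with "compile_with_inductor" as the marker it looks for.
import torch
import torch.fx.traceback as fx_traceback
import torch.nn.functional as F


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Under inductor lite, only this annotated region is compiled by inductor;
    # all other ops in the traced graph fall back to the eager kernels.
    with fx_traceback.annotate({"compile_with_inductor": "flex_attention"}):
        return F.scaled_dot_product_attention(q, k, v)
```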