30 changes: 2 additions & 28 deletions torchtitan/config/job_config.py
@@ -857,37 +857,11 @@ class Experimental:
     needs to ensure that the path can be imported.
     """
 
-    # "none", "all", "only_fsdp"
-    bucket_all_gathers_fx: str = "none"
-
-    # "none", "all"
-    bucket_reduce_scatters_fx: str = "none"
-
-    reorder_for_compute_comm_overlap: bool = False
-    """
-    Whether to enable inductor comm reordering passes
-    """
-
-    reorder_for_compute_comm_overlap_passes: list[str] = field(
-        default_factory=lambda: [
-            "sink_waits_iterative",
-            "reorder_communication_preserving_peak_memory",
-        ]
-    )
-    """
-    Sequence of reordering passes (names of functions inside _inductor.comms) to call,
-    if reorder_for_compute_comm_overlap is enabled.
-    """
-
-    reorder_prefetch_limit: int | None = None
-    """
-    How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
-    pass is enabled. default of None means unlimited
-    """
+    # "aten" (default), "inductor", "none"
+    comms_bucket_reorder_strategy: str = "aten"
 
     autop_force_bf16: bool = False
 
     enable_simplefsdp_passes: bool = False
 
 @dataclass
 class Validation:
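The removed per-knob fields are collapsed into the single comms_bucket_reorder_strategy string. As a minimal usage sketch (assuming JobConfig is the top-level dataclass defined in this same module and default-constructs its experimental section; none of that plumbing is part of this diff), selecting a strategy programmatically would look like:

# Sketch only: the field and section names come from the diff above,
# the JobConfig construction around them is assumed.
from torchtitan.config.job_config import JobConfig

job_config = JobConfig()
# valid values per the new field: "aten" (default), "inductor", "none"
job_config.experimental.comms_bucket_reorder_strategy = "inductor"

In practice the value would come from the job TOML file or a command-line override rather than being set by hand.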
42 changes: 2 additions & 40 deletions torchtitan/train.py
@@ -33,6 +33,7 @@
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+from autoparallel.auto_bucketing import configure_inductor_for_autobucketing
 
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
@@ -122,46 +123,7 @@ def __init__(self, job_config: JobConfig):
         torch._inductor.config.allow_buffer_reuse = False
 
         # allow configuring inductor comms optimizations from torchtitan commandline
-        if job_config.experimental.enable_simplefsdp_passes:
-            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
-            from autoparallel.auto_bucketing import (
-                simple_fsdp_autobucketing_reordering_pass,
-                simplefsdp_autobucketing_config,
-            )
-
-            torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.reorder_for_peak_memory = False
-            torch._inductor.config.reorder_for_compute_comm_overlap = True
-            simplefsdp_autobucketing_config.save_estimation_path = (
-                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
-            )
-            simple_fsdp_autobucketing_reordering_pass = partial(
-                simple_fsdp_autobucketing_reordering_pass,
-                configs=simplefsdp_autobucketing_config,
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
-                simple_fsdp_autobucketing_reordering_pass
-            ]
-
-            # Don't use both sets of passes at the same time!
-            torch._inductor.config.bucket_all_gathers_fx = "none"
-            torch._inductor.config.bucket_reduce_scatters_fx = "none"
-        else:
-            torch._inductor.config.bucket_all_gathers_fx = (
-                job_config.experimental.bucket_all_gathers_fx
-            )
-            torch._inductor.config.bucket_reduce_scatters_fx = (
-                job_config.experimental.bucket_reduce_scatters_fx
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap = (
-                job_config.experimental.reorder_for_compute_comm_overlap
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
-                job_config.experimental.reorder_for_compute_comm_overlap_passes
-            )
-            torch._inductor.config.reorder_prefetch_limit = (
-                job_config.experimental.reorder_prefetch_limit
-            )
+        configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy)
 
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
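With this change the trainer delegates the inductor comm bucketing/reordering setup to autoparallel. The real configure_inductor_for_autobucketing lives in the autoparallel package and is not shown in this diff; purely as an illustrative sketch of the kind of torch._inductor.config knobs it stands in for (attribute names and values are taken from the code deleted above, while the strategy-to-setting mapping is an assumption, not the actual implementation):

# Illustrative sketch only -- not the real autoparallel implementation.
# It mirrors the inductor knobs the deleted branch used to set directly,
# keyed by the new strategy string; the actual mapping may differ.
import torch._inductor.config as inductor_config

def sketch_configure_inductor_for_autobucketing(strategy: str) -> None:
    if strategy == "aten":
        # bucket collectives via the fx-level passes the removed
        # bucket_*_fx options exposed ("none", "all", "only_fsdp")
        inductor_config.bucket_all_gathers_fx = "all"
        inductor_config.bucket_reduce_scatters_fx = "all"
    elif strategy == "inductor":
        # rely on inductor's compute/comm overlap reordering passes instead
        inductor_config.reorder_for_compute_comm_overlap = True
        inductor_config.reorder_for_compute_comm_overlap_passes = [
            "sink_waits_iterative",
            "reorder_communication_preserving_peak_memory",
        ]
    elif strategy != "none":
        raise ValueError(f"unknown comms_bucket_reorder_strategy: {strategy!r}")
    # "none": leave inductor's defaults untouched

The call site in train.py simply forwards job_config.experimental.comms_bucket_reorder_strategy, so switching strategies stays a one-line config change.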