
Commit c6e25bd

Whc/knobs (#1994)
Needs to merge in lockstep with meta-pytorch/autoparallel#233.
Parent: 9dc0bd8 · Commit: c6e25bd

File tree: 2 files changed (+4, -68 lines)


torchtitan/config/job_config.py

Lines changed: 2 additions & 28 deletions
@@ -857,37 +857,11 @@ class Experimental:
     needs to ensure that the path can be imported.
     """
 
-    # "none", "all", "only_fsdp"
-    bucket_all_gathers_fx: str = "none"
-
-    # "none", "all"
-    bucket_reduce_scatters_fx: str = "none"
-
-    reorder_for_compute_comm_overlap: bool = False
-    """
-    Whether to enable inductor comm reordering passes
-    """
-
-    reorder_for_compute_comm_overlap_passes: list[str] = field(
-        default_factory=lambda: [
-            "sink_waits_iterative",
-            "reorder_communication_preserving_peak_memory",
-        ]
-    )
-    """
-    Sequence of reordering passes (names of functions inside _inductor.comms) to call,
-    if reorder_for_compute_comm_overlap is enabled.
-    """
-
-    reorder_prefetch_limit: int | None = None
-    """
-    How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
-    pass is enabled. default of None means unlimited
-    """
+    # "aten" (default), "inductor", "none"
+    comms_bucket_reorder_strategy: str = "aten"
 
     autop_force_bf16: bool = False
 
-    enable_simplefsdp_passes: bool = False
 
 @dataclass
 class Validation:
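The accepted values are only documented in the field comment ("aten" by default, "inductor", "none"). As a minimal sketch, assuming one wanted to guard the knob at config-parse time, a hypothetical helper might look like the following; check_comms_bucket_reorder_strategy and _VALID_STRATEGIES are illustrative names, not part of this commit or of torchtitan.

# Hypothetical sketch, not part of this commit: fail fast on an unknown
# strategy instead of passing a typo through to autoparallel/inductor.
_VALID_STRATEGIES = ("aten", "inductor", "none")  # values listed in the diff comment


def check_comms_bucket_reorder_strategy(strategy: str) -> str:
    """Return the strategy unchanged, or raise with a clear message."""
    if strategy not in _VALID_STRATEGIES:
        raise ValueError(
            "experimental.comms_bucket_reorder_strategy must be one of "
            f"{_VALID_STRATEGIES}, got {strategy!r}"
        )
    return strategy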

torchtitan/train.py

Lines changed: 2 additions & 40 deletions
@@ -33,6 +33,7 @@
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+from autoparallel.auto_bucketing import configure_inductor_for_autobucketing
 
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
@@ -122,46 +123,7 @@ def __init__(self, job_config: JobConfig):
         torch._inductor.config.allow_buffer_reuse = False
 
         # allow configuring inductor comms optimizations from torchtitan commandline
-        if job_config.experimental.enable_simplefsdp_passes:
-            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
-            from autoparallel.auto_bucketing import (
-                simple_fsdp_autobucketing_reordering_pass,
-                simplefsdp_autobucketing_config,
-            )
-
-            torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.reorder_for_peak_memory = False
-            torch._inductor.config.reorder_for_compute_comm_overlap = True
-            simplefsdp_autobucketing_config.save_estimation_path = (
-                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
-            )
-            simple_fsdp_autobucketing_reordering_pass = partial(
-                simple_fsdp_autobucketing_reordering_pass,
-                configs=simplefsdp_autobucketing_config,
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
-                simple_fsdp_autobucketing_reordering_pass
-            ]
-
-            # Don't use both sets of passes at the same time!
-            torch._inductor.config.bucket_all_gathers_fx = "none"
-            torch._inductor.config.bucket_reduce_scatters_fx = "none"
-        else:
-            torch._inductor.config.bucket_all_gathers_fx = (
-                job_config.experimental.bucket_all_gathers_fx
-            )
-            torch._inductor.config.bucket_reduce_scatters_fx = (
-                job_config.experimental.bucket_reduce_scatters_fx
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap = (
-                job_config.experimental.reorder_for_compute_comm_overlap
-            )
-            torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
-                job_config.experimental.reorder_for_compute_comm_overlap_passes
-            )
-            torch._inductor.config.reorder_prefetch_limit = (
-                job_config.experimental.reorder_prefetch_limit
-            )
+        configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy)
 
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).

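The body of configure_inductor_for_autobucketing ships in meta-pytorch/autoparallel#233 and is not visible in this diff. As a rough mental model only, the sketch below extrapolates what the helper might set from the torchtitan branches deleted above; the real implementation may differ, and the new default "aten" strategy has no counterpart in the removed code.

import torch


def configure_inductor_for_autobucketing_sketch(strategy: str) -> None:
    """Assumed behavior reconstructed from the deleted torchtitan code, not the
    actual autoparallel helper (see meta-pytorch/autoparallel#233)."""
    if strategy == "inductor":
        # Roughly the old manual path: fx-graph bucketing plus comm reordering.
        torch._inductor.config.bucket_all_gathers_fx = "all"
        torch._inductor.config.bucket_reduce_scatters_fx = "all"
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
            "sink_waits_iterative",
            "reorder_communication_preserving_peak_memory",
        ]
    elif strategy == "none":
        # Roughly the old defaults: no bucketing, no reordering.
        torch._inductor.config.bucket_all_gathers_fx = "none"
        torch._inductor.config.bucket_reduce_scatters_fx = "none"
        torch._inductor.config.reorder_for_compute_comm_overlap = False
    elif strategy == "aten":
        # Default in this commit; aten-level bucketing is implemented in
        # autoparallel and its knobs are not visible in this diff.
        pass

Whatever the exact switches, the trainer now makes a single call and autoparallel owns the inductor knobs, which is why the per-flag plumbing above could be deleted.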