|
33 | 33 | maybe_enable_memory_snapshot, |
34 | 34 | maybe_enable_profiling, |
35 | 35 | ) |
| 36 | +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing |
36 | 37 |
|
37 | 38 |
|
38 | 39 | class Trainer(torch.distributed.checkpoint.stateful.Stateful): |
@@ -122,46 +123,7 @@ def __init__(self, job_config: JobConfig): |
122 | 123 | torch._inductor.config.allow_buffer_reuse = False |
123 | 124 |
|
124 | 125 | # allow configuring inductor comms optimizations from torchtitan commandline |
125 | | - if job_config.experimental.enable_simplefsdp_passes: |
126 | | - # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) |
127 | | - from autoparallel.auto_bucketing import ( |
128 | | - simple_fsdp_autobucketing_reordering_pass, |
129 | | - simplefsdp_autobucketing_config, |
130 | | - ) |
131 | | - |
132 | | - torch._inductor.config.allow_buffer_reuse = False |
133 | | - torch._inductor.config.reorder_for_peak_memory = False |
134 | | - torch._inductor.config.reorder_for_compute_comm_overlap = True |
135 | | - simplefsdp_autobucketing_config.save_estimation_path = ( |
136 | | - "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" |
137 | | - ) |
138 | | - simple_fsdp_autobucketing_reordering_pass = partial( |
139 | | - simple_fsdp_autobucketing_reordering_pass, |
140 | | - configs=simplefsdp_autobucketing_config, |
141 | | - ) |
142 | | - torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ |
143 | | - simple_fsdp_autobucketing_reordering_pass |
144 | | - ] |
145 | | - |
146 | | - # Don't use both sets of passes at the same time! |
147 | | - torch._inductor.config.bucket_all_gathers_fx = "none" |
148 | | - torch._inductor.config.bucket_reduce_scatters_fx = "none" |
149 | | - else: |
150 | | - torch._inductor.config.bucket_all_gathers_fx = ( |
151 | | - job_config.experimental.bucket_all_gathers_fx |
152 | | - ) |
153 | | - torch._inductor.config.bucket_reduce_scatters_fx = ( |
154 | | - job_config.experimental.bucket_reduce_scatters_fx |
155 | | - ) |
156 | | - torch._inductor.config.reorder_for_compute_comm_overlap = ( |
157 | | - job_config.experimental.reorder_for_compute_comm_overlap |
158 | | - ) |
159 | | - torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( |
160 | | - job_config.experimental.reorder_for_compute_comm_overlap_passes |
161 | | - ) |
162 | | - torch._inductor.config.reorder_prefetch_limit = ( |
163 | | - job_config.experimental.reorder_prefetch_limit |
164 | | - ) |
| 126 | + configure_inductor_for_autobucketing(job_config.experimental.comms_bucket_reorder_strategy) |
165 | 127 |
|
166 | 128 | # Set random seed, and maybe enable deterministic mode |
167 | 129 | # (mainly for debugging, expect perf loss). |
|
0 commit comments