|
8 | 8 | import os |
9 | 9 | import time |
10 | 10 | from datetime import timedelta |
| 11 | +from functools import partial |
11 | 12 | from typing import Any, Generator, Iterable, Optional |
12 | 13 |
|
13 | 14 | import torch |
@@ -125,32 +126,29 @@ def __init__(self, job_config: JobConfig): |
125 | 126 |
|
126 | 127 | # allow configuring inductor comms optimizations from torchtitan commandline |
127 | 128 | if job_config.experimental.enable_simplefsdp_passes: |
128 | | - try: |
129 | | - from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir |
130 | | - except ImportError: |
131 | | - print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") |
132 | | - raise |
133 | | - |
134 | | - # Configs from Ruisi |
| 129 | + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) |
| 130 | + from autoparallel.auto_bucketing import ( |
| 131 | + simple_fsdp_autobucketing_reordering_pass, |
| 132 | + simplefsdp_autobucketing_config, |
| 133 | + ) |
135 | 134 |
|
136 | | - # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives |
137 | | - torch._inductor.config.simplefsdp.relax_ratio = 0 |
138 | 135 | torch._inductor.config.allow_buffer_reuse = False |
139 | | - torch._inductor.config.simplefsdp.estimate_ir = False |
140 | | - torch._inductor.config.simplefsdp.estimate_verbose = False |
141 | | - torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" |
142 | | - # set to True after the first communication estimation results are saved. This would reduce decision making time. |
143 | | - torch._inductor.config.simplefsdp.load_cache = False |
144 | | - torch._inductor.config.simplefsdp.enable_bucket_ir = True |
145 | | - torch._inductor.config.simplefsdp.enable_reorder_ir = True |
146 | | - torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d |
147 | | - torch._inductor.config.simplefsdp.peak_memory_offset = 0 |
148 | | - torch._inductor.config.simplefsdp.bucketing_type = "auto" |
| 136 | + torch._inductor.config.reorder_for_peak_memory = False |
| 137 | + torch._inductor.config.reorder_for_compute_comm_overlap = True |
| 138 | + simplefsdp_autobucketing_config.save_estimation_path = ( |
| 139 | + "/tmp/estimation_mast.pkl" |
| 140 | + ) |
| 141 | + simple_fsdp_autobucketing_reordering_pass = partial( |
| 142 | + simple_fsdp_autobucketing_reordering_pass, |
| 143 | + configs=simplefsdp_autobucketing_config, |
| 144 | + ) |
| 145 | + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ |
| 146 | + simple_fsdp_autobucketing_reordering_pass |
| 147 | + ] |
149 | 148 |
|
150 | 149 | # Don't use both sets of passes at the same time! |
151 | 150 | torch._inductor.config.bucket_all_gathers_fx = "none" |
152 | 151 | torch._inductor.config.bucket_reduce_scatters_fx = "none" |
153 | | - torch._inductor.config.reorder_for_compute_comm_overlap = False |
154 | 152 | else: |
155 | 153 | torch._inductor.config.bucket_all_gathers_fx = ( |
156 | 154 | job_config.experimental.bucket_all_gathers_fx |
|
0 commit comments