
Commit 75fb2eb

add simplefsdp's autobucketing pass entry (#1658)
As titled, this PR adds an entry for SimpleFSDP's autobucketing pass in AutoParallel. The original code is in pytorch/pytorch#160282; the main code for the autobucketing pass will be added to the autoparallel repo.
1 parent bfa9f7f commit 75fb2eb

File tree

2 files changed: +22, -20 lines

torchtitan/experiments/auto_parallel/README.md

Lines changed: 4 additions & 0 deletions
@@ -4,4 +4,8 @@ requires installing git@github.com:pytorch-labs/autoparallel.git
 
 `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`
 
+Use simplefsdp's autobucketing pass:
+
+`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`
+
 (or llama3-8b.toml)
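As the closing line of this hunk notes, the same invocation works with the 8B config; an illustrative substitution, assuming the file sits in the same train_configs directory:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3-8b.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`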

torchtitan/train.py

Lines changed: 18 additions & 20 deletions
@@ -8,6 +8,7 @@
 import os
 import time
 from datetime import timedelta
+from functools import partial
 from typing import Any, Generator, Iterable, Optional
 
 import torch
@@ -130,32 +131,29 @@ def __init__(self, job_config: JobConfig):
 
         # allow configuring inductor comms optimizations from torchtitan commandline
         if job_config.experimental.enable_simplefsdp_passes:
-            try:
-                from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
-            except ImportError:
-                print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
-                raise
-
-            # Configs from Ruisi
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+                simplefsdp_autobucketing_config,
+            )
 
-            # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
-            torch._inductor.config.simplefsdp.relax_ratio = 0
             torch._inductor.config.allow_buffer_reuse = False
-            torch._inductor.config.simplefsdp.estimate_ir = False
-            torch._inductor.config.simplefsdp.estimate_verbose = False
-            torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
-            # set to True after the first communication estimation results are saved. This would reduce decision making time.
-            torch._inductor.config.simplefsdp.load_cache = False
-            torch._inductor.config.simplefsdp.enable_bucket_ir = True
-            torch._inductor.config.simplefsdp.enable_reorder_ir = True
-            torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d
-            torch._inductor.config.simplefsdp.peak_memory_offset = 0
-            torch._inductor.config.simplefsdp.bucketing_type = "auto"
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simplefsdp_autobucketing_config.save_estimation_path = (
+                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
+            )
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
 
             # Don't use both sets of passes at the same time!
             torch._inductor.config.bucket_all_gathers_fx = "none"
             torch._inductor.config.bucket_reduce_scatters_fx = "none"
-            torch._inductor.config.reorder_for_compute_comm_overlap = False
         else:
            torch._inductor.config.bucket_all_gathers_fx = (
                job_config.experimental.bucket_all_gathers_fx
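For context, a minimal sketch of the mechanism the train.py change relies on: inductor lets a custom compute/communication reordering pass be installed through `torch._inductor.config.reorder_for_compute_comm_overlap_passes`, which this diff does via `partial`. The pass name and no-op body below are hypothetical placeholders for illustration, not the actual autobucketing pass from autoparallel:

```python
from functools import partial

import torch


def my_reordering_pass(snodes, configs=None):
    # An inductor reordering pass receives the scheduler's node list and
    # returns it in the desired execution order. A real pass (such as
    # simple_fsdp_autobucketing_reordering_pass) would bucket collectives
    # and interleave them with compute; this placeholder returns the
    # nodes unchanged.
    return snodes


# Mirror the wiring from the diff above: turn on overlap reordering and
# install the custom pass, bound to its config object, as the pass list.
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    partial(my_reordering_pass, configs=None)
]
```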
