Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torchtitan/experiments/auto_parallel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ requires installing git@github.com:pytorch-labs/autoparallel.git

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

Use simplefsdp's autobucketing pass:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`

(or llama3-8b.toml)
38 changes: 18 additions & 20 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from datetime import timedelta
from functools import partial
from typing import Any, Generator, Iterable, Optional

import torch
Expand Down Expand Up @@ -130,32 +131,29 @@ def __init__(self, job_config: JobConfig):

# allow configuring inductor comms optimizations from torchtitan commandline
if job_config.experimental.enable_simplefsdp_passes:
try:
from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir
except ImportError:
print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c")
raise

# Configs from Ruisi
# enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
from autoparallel.auto_bucketing import (
simple_fsdp_autobucketing_reordering_pass,
simplefsdp_autobucketing_config,
)

# set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives
torch._inductor.config.simplefsdp.relax_ratio = 0
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.simplefsdp.estimate_ir = False
torch._inductor.config.simplefsdp.estimate_verbose = False
torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl"
# set to True after the first communication estimation results are saved. This would reduce decision making time.
torch._inductor.config.simplefsdp.load_cache = False
torch._inductor.config.simplefsdp.enable_bucket_ir = True
torch._inductor.config.simplefsdp.enable_reorder_ir = True
torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d
torch._inductor.config.simplefsdp.peak_memory_offset = 0
torch._inductor.config.simplefsdp.bucketing_type = "auto"
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = True
simplefsdp_autobucketing_config.save_estimation_path = (
"/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
)
simple_fsdp_autobucketing_reordering_pass = partial(
simple_fsdp_autobucketing_reordering_pass,
configs=simplefsdp_autobucketing_config,
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
simple_fsdp_autobucketing_reordering_pass
]

# Don't use both sets of passes at the same time!
torch._inductor.config.bucket_all_gathers_fx = "none"
torch._inductor.config.bucket_reduce_scatters_fx = "none"
torch._inductor.config.reorder_for_compute_comm_overlap = False
else:
torch._inductor.config.bucket_all_gathers_fx = (
job_config.experimental.bucket_all_gathers_fx
Expand Down
Loading