diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md index ef66a59166..7e112329b9 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/auto_parallel/README.md @@ -4,4 +4,8 @@ requires installing git@github.com:pytorch-labs/autoparallel.git `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` +Use simplefsdp's autobucketing pass: + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` + (or llama3-8b.toml) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2829aa3c55..b69b74faac 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,6 +8,7 @@ import os import time from datetime import timedelta +from functools import partial from typing import Any, Generator, Iterable, Optional import torch @@ -130,32 +131,29 @@ def __init__(self, job_config: JobConfig): # allow configuring inductor comms optimizations from torchtitan commandline if job_config.experimental.enable_simplefsdp_passes: - try: - from torch._inductor.simple_fsdp.bucket import bucket_fsdp_all_gather_concat_on_scheduler_ir - except ImportError: - print("Must use pytorch from unlanded https://github.com/pytorch/pytorch/pull/160282, e.g. torchtitan_conda_prod:5e4101faa448c2ee6b62ddd76ee08e8c") - raise - - # Configs from Ruisi + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) + from autoparallel.auto_bucketing import ( + simple_fsdp_autobucketing_reordering_pass, + simplefsdp_autobucketing_config, + ) - # set to 0.1 if you want to make bucketing more efficient with mixed dtype collectives - torch._inductor.config.simplefsdp.relax_ratio = 0 torch._inductor.config.allow_buffer_reuse = False - torch._inductor.config.simplefsdp.estimate_ir = False - torch._inductor.config.simplefsdp.estimate_verbose = False - torch._inductor.config.simplefsdp.save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast_"+job_config.model.flavor+".pkl" - # set to True after the first communication estimation results are saved. This would reduce decision making time. - torch._inductor.config.simplefsdp.load_cache = False - torch._inductor.config.simplefsdp.enable_bucket_ir = True - torch._inductor.config.simplefsdp.enable_reorder_ir = True - torch._inductor.config.simplefsdp.simplefsdp_only = False # False for 2d True for 1d - torch._inductor.config.simplefsdp.peak_memory_offset = 0 - torch._inductor.config.simplefsdp.bucketing_type = "auto" + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + simplefsdp_autobucketing_config.save_estimation_path = ( + "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" + ) + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] # Don't use both sets of passes at the same time! torch._inductor.config.bucket_all_gathers_fx = "none" torch._inductor.config.bucket_reduce_scatters_fx = "none" - torch._inductor.config.reorder_for_compute_comm_overlap = False else: torch._inductor.config.bucket_all_gathers_fx = ( job_config.experimental.bucket_all_gathers_fx