From b90c8dafcd843ceed15484483d617e00dd73ae35 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 9 Jul 2024 23:15:33 -0700
Subject: [PATCH 1/3] some compile-related improvements

[ghstack-poisoned]
---
 test_runner.py                               |  9 +++++++++
 torchtitan/parallelisms/parallelize_llama.py | 16 ++++------------
 train_configs/llama3_8b.toml                 |  2 +-
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/test_runner.py b/test_runner.py
index cba63544aa..1331eae118 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -153,6 +153,15 @@ def build_test_list():
             "1D compile",
             "1d_compile",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.compile --model.norm_type=rmsnorm --selective_ac_option=op",
+                ],
+            ],
+            "1D compile with selective op AC",
+            "1d_compile_sac_op",
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 6cafa4abe4..aca77df1b7 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -432,22 +432,14 @@ def apply_compile(model, job_config: JobConfig):
             "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
         )
 
+    # NOTE(anijain): enable the following flag to accelerate compilation
+    torch._dynamo.config.inline_inbuilt_nn_modules = True
+
     for layer_id, transformer_block in model.layers.named_children():
         # turn on per-transformer block compile after AC wrapping and before FSDP
-        # TODO: dynamic shape have some issues so we turn it off for now.
-        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
-        # compile time.
-        # torch._dynamo.config.inline_inbuilt_nn_modules = True
-        transformer_block = torch.compile(transformer_block, dynamic=False)
+        transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
-        # some temp flags for torch.compile enablement + SAC
-        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
-            True
-        )
-
     logger.info("Compiled each TransformerBlock with torch.compile")
     return model
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 2fb89004ad..05c2ec83b6 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -34,7 +34,7 @@ steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
 fp8_linear = false
-compile = false
+compile = true
 dataset = "c4"
 
 [experimental]

From 1a5f0b3d5f2922b5629284ed659a9dd4c987313f Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 9 Jul 2024 23:17:01 -0700
Subject: [PATCH 2/3] Update on "some compile-related improvements"

[ghstack-poisoned]
---
 train_configs/llama3_8b.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 05c2ec83b6..2fb89004ad 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -34,7 +34,7 @@ steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
 fp8_linear = false
-compile = true
+compile = false
 dataset = "c4"
 
 [experimental]

From dc4193618446b895d990f39628025ddbb5fbc7de Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 10 Jul 2024 00:02:27 -0700
Subject: [PATCH 3/3] Update on "some compile-related improvements"

[ghstack-poisoned]
---
 test_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test_runner.py b/test_runner.py
index 1331eae118..ec1d791f08 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -156,7 +156,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--training.compile --model.norm_type=rmsnorm --selective_ac_option=op",
+                    "--training.compile --model.norm_type=rmsnorm --activation_checkpoint.selective_ac_option=op",
                 ],
             ],
             "1D compile with selective op AC",
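
Taken together, the three commits compile each TransformerBlock individually with fullgraph=True, enable Dynamo's inline_inbuilt_nn_modules flag to speed up compilation, drop the no-longer-needed experimental SAC flag, and add an integration test for compile combined with selective op-level activation checkpointing (fixing the test's CLI flag to the fully qualified --activation_checkpoint.selective_ac_option=op in the last commit). The sketch below distills the compile wiring outside torchtitan: SimpleBlock and SimpleModel are hypothetical stand-ins for the repo's TransformerBlock and Transformer, and only the apply_compile body mirrors what the patch actually does.

    import torch
    import torch.nn as nn


    class SimpleBlock(nn.Module):
        """Hypothetical stand-in for torchtitan's TransformerBlock."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.proj(self.norm(x))


    class SimpleModel(nn.Module):
        """Hypothetical stand-in for the Transformer: blocks live in self.layers."""

        def __init__(self, dim: int, n_layers: int) -> None:
            super().__init__()
            self.layers = nn.ModuleDict(
                {str(i): SimpleBlock(dim) for i in range(n_layers)}
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for block in self.layers.values():
                x = block(x)
            return x


    def apply_compile(model: nn.Module) -> nn.Module:
        # Inline nn.Module code into the traced graph instead of specializing
        # on module instances; per the patch's NOTE, this accelerates compilation.
        torch._dynamo.config.inline_inbuilt_nn_modules = True

        # Compile block by block (in the real code path: after AC wrapping and
        # before FSDP); fullgraph=True turns any graph break inside a block
        # into a hard error instead of a silent graph split.
        for layer_id, block in model.layers.named_children():
            model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))
        return model


    model = apply_compile(SimpleModel(dim=256, n_layers=4))
    out = model(torch.randn(8, 16, 256))  # first forward triggers per-block compilation

Compiling per block rather than wrapping the whole model keeps each compiled region inside the AC and FSDP wrappers that torchtitan applies at block granularity, which is why the patch can also delete the temporary _experimental_support_context_fn_in_torch_utils_checkpoint flag that the whole-model path needed.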