Commit 2071dfb

Update on "compiled RMSNorm"
On the Llama3 8B model without activation checkpointing (AC), `compiled_rmsnorm` is ~9% faster than `rmsnorm` but ~2% slower than `fused_rmsnorm`. See below for details.

rmsnorm
<img width="757" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/79645518-e38b-4ddb-b01d-b0c93ec27dd4">

compiled_rmsnorm
<img width="754" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/c457b388-793f-452b-9bce-17bc1823df66">

fused_rmsnorm
<img width="753" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/ea1db7ad-5887-4efa-9788-e708e4b40428">

[ghstack-poisoned]
1 parent 4c33e52 commit 2071dfb

2 files changed: +6 -2 lines changed

estimation.py

Lines changed: 5 additions & 1 deletion
@@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig):
         )
         job_config.model.norm_type = "rmsnorm"
 
+    if job_config.model.norm_type == "compiled_rmsnorm":
+        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
+        job_config.model.norm_type = "rmsnorm"
+
     if job_config.training.compile:
-        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
 
     parallel_dims = ParallelDims(
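
The fallback added in this hunk can be read in isolation as a small config-normalization step. The following is a minimal sketch only: the dataclasses are hypothetical stand-ins for the real `JobConfig`, and `apply_estimation_fallbacks` is an illustrative name that does not appear in the commit.

```python
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    # Hypothetical stand-in for job_config.model
    norm_type: str = "rmsnorm"


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for job_config.training
    compile: bool = False


@dataclass
class JobConfig:
    # Hypothetical stand-in, not the real torchtitan JobConfig
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)


def apply_estimation_fallbacks(job_config: JobConfig) -> JobConfig:
    """Downgrade options the memory estimator cannot handle (mutates in place)."""
    # Mirrors the branch added by this commit: compiled RMSNorm falls back to rmsnorm.
    if job_config.model.norm_type == "compiled_rmsnorm":
        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
        job_config.model.norm_type = "rmsnorm"

    # Mirrors the existing branch: compile mode falls back to eager mode.
    if job_config.training.compile:
        logger.info("Compile mode is not supported yet. Switching to eager mode.")
        job_config.training.compile = False

    return job_config
```

Called on a config with `norm_type="compiled_rmsnorm"` and `compile=True`, the sketch leaves `norm_type` as `"rmsnorm"` and `compile` as `False`, which matches the behavior the diff describes.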

test_runner.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--memory_estimation.enabled",
+                    "--memory_estimation.enabled --model.norm_type rmsnorm",
                 ]
             ],
             "FSDP2 Memory Tracking and Estimation",

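The new override packs two flags into a single string. Assuming the test runner eventually hands this string to a shell-style command line (an assumption; the runner's command assembly is not shown in this diff), it would tokenize as sketched below. `tokenize_override` is a hypothetical helper for illustration, not part of `test_runner.py`.

```python
import shlex


def tokenize_override(override: str) -> list[str]:
    """Split an override string the way a POSIX shell would split arguments."""
    return shlex.split(override)


tokens = tokenize_override("--memory_estimation.enabled --model.norm_type rmsnorm")
# tokens == ["--memory_estimation.enabled", "--model.norm_type", "rmsnorm"]
```

Under that assumption, the single string behaves like two separate options: a boolean flag that enables memory estimation, and `--model.norm_type` with the value `rmsnorm`.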