Autoparallel as an experiment into main #2054
base: main
Changes from all commits
3ccd12c
e6d2caf
68476b3
9ee9f75
f6e4099
4d7ee8a
b801d0b
b099cf9
b3587d9
42c2c07
d93845e
60f5f11
6c782eb
4712163
3f04d22
8e50870
91c5639
1233902
4f8677b
714cc5b
45647b3
bfa9f7f
75fb2eb
8769396
db22479
87ef4e0
9dc0bd8
c6e25bd
26410e8
e6ea814
7abede8
d2e76b7
472b4ad
ac0def9
a24ef07
da611e4
6cc8caa
d54a6d4
acd9588
2b1fb92
@@ -865,6 +865,11 @@ class Experimental:
    needs to ensure that the path can be imported.
    """

    # "aten" (default), "inductor", "none"
    comms_bucket_reorder_strategy: str = "aten"

    autop_force_bf16: bool = False


@dataclass
class Validation:

Contributor
Is there a way for an experiment to add a config knob without polluting the top-level file?
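For context, here is a minimal, runnable sketch of the two new knobs as plain dataclass fields, together with a hypothetical sanity check; `validate_experimental` and the `__main__` usage are illustrative assumptions, not part of the PR.

```python
from dataclasses import dataclass


@dataclass
class Experimental:
    # "aten" (default), "inductor", "none"
    comms_bucket_reorder_strategy: str = "aten"
    autop_force_bf16: bool = False


def validate_experimental(cfg: Experimental) -> None:
    # Hypothetical helper: reject strategy values outside the documented set.
    allowed = ("aten", "inductor", "none")
    if cfg.comms_bucket_reorder_strategy not in allowed:
        raise ValueError(
            f"comms_bucket_reorder_strategy must be one of {allowed}, "
            f"got {cfg.comms_bucket_reorder_strategy!r}"
        )


if __name__ == "__main__":
    cfg = Experimental(comms_bucket_reorder_strategy="inductor", autop_force_bf16=True)
    validate_experimental(cfg)
    print(cfg)
```

On the command line these fields would presumably be reachable as `--experimental.comms_bucket_reorder_strategy` and `--experimental.autop_force_bf16`, mirroring the `--experimental.enable_simplefsdp_passes` flag used in the README below.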
@@ -0,0 +1,11 @@
## Auto Parallel

requires installing [email protected]:pytorch-labs/autoparallel.git

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

Use simplefsdp's autobucketing pass:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`

(or llama3-8b.toml)
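Before launching either command, a quick pre-flight check can confirm the dependency is present; this sketch assumes the repository installs under the import name `autoparallel`, which the README does not state.

```python
# Illustrative only: verify the autoparallel dependency is importable before
# launching a run; assumes the package's import name is "autoparallel".
try:
    import autoparallel  # noqa: F401
except ImportError as exc:
    raise SystemExit(
        "autoparallel is not installed; install it from "
        "[email protected]:pytorch-labs/autoparallel.git first"
    ) from exc
```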
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model
from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    model_args = copy.deepcopy(deepseekv3_args)

    # Reset every config whose name does not mention "flex_attn" back to the
    # default attention settings; flex-attention variants are left untouched.
    default_args = DeepSeekV3ModelArgs()
    for config, args in model_args.items():
        if "flex_attn" in config:
            continue
        args.use_flex_attn = default_args.use_flex_attn
        args.attn_mask_type = default_args.attn_mask_type

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
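To make the effect of that loop concrete, here is a self-contained sketch of the same "deep-copy the config registry, then reset non-flex variants to the defaults" pattern; `FakeArgs` and the config names are stand-ins, not torchtitan code.

```python
import copy
from dataclasses import dataclass


@dataclass
class FakeArgs:
    # Stand-in for DeepSeekV3ModelArgs, reduced to the two fields the loop touches.
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"


base_args = {
    "debugmodel": FakeArgs(use_flex_attn=True, attn_mask_type="block_causal"),
    "debugmodel_flex_attn": FakeArgs(use_flex_attn=True, attn_mask_type="block_causal"),
}

model_args = copy.deepcopy(base_args)  # never mutate the shared registry
default_args = FakeArgs()
for config, args in model_args.items():
    if "flex_attn" in config:
        continue  # explicitly flex-attention variants keep their settings
    args.use_flex_attn = default_args.use_flex_attn
    args.attn_mask_type = default_args.attn_mask_type

print(model_args["debugmodel"].use_flex_attn)            # False
print(model_args["debugmodel_flex_attn"].use_flex_attn)  # True
```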
this one might be worth landing in main. it's purely inductor-specific and used by simple-fsdp as well.
cc @ruisizhang123 @fegin @tianyu-l
If it is shared by 2 different models/experiments, I think it is okay to add it to the core job_config. This will be used by full dtensor and the compiler toolkit, IIUC.
It's inductor-specific bucketing. Maybe we should upstream the inductor bucketing code to pytorch?
What happens if users run core FSDP2 with this option? Is it a no-op?
I think we encounter a tricky case where multiple experiments would share config that's not in core. I would say the "right" way for now might be just duplicating this config into their own custom job_config.py, but the deeper reason is that we need to reinvent the config system -- the idea is to let each component have its own config, rather than sharing a central config. This seems a bit urgent, as we are hitting such issues from different angles recently.
cc @ailzhang
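As a sketch of the "duplicate the config into the experiment's own job_config.py" route suggested above, the experiment could carry a small dataclass of its own rather than touching the shared Experimental section; the class and field layout below is an illustrative assumption, not the PR's code.

```python
from dataclasses import dataclass, field


@dataclass
class AutoParallel:
    # Same knobs as in the diff above, but scoped to the experiment.
    comms_bucket_reorder_strategy: str = "aten"  # "aten", "inductor", or "none"
    autop_force_bf16: bool = False


@dataclass
class JobConfig:
    # Experiment-local extension; the core torchtitan JobConfig stays untouched.
    auto_parallel: AutoParallel = field(default_factory=AutoParallel)
```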
Nothing would happen, because the graphs the compiler sees are compute-only graphs, so the bucketing pass would not take effect (as no comms are bucketed/reordered).
Maybe we should have a config class specific to pt2-frontier lolll