adding variable length attention to llama3 8b #2000
base: main
Changes from all commits: bd0c0fc, ea085c1, b3f723d, d8a6254, 9381ada, 0012170, 4fef6eb, 0d32d5a, 4d80f4e, 9c99fcb
First changed file:

@@ -67,9 +67,9 @@ def parallelize_deepseekv3(
     if (
         job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
+        and model.model_args.attn_type != "sdpa"
     ):
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+        raise NotImplementedError("CP support is only supported for SDPA.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters

@@ -85,13 +85,12 @@ def parallelize_deepseekv3(
             "Currently, float8 tensorwise TP is not tested for deepseekv3"
         )

-        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+        attn_type = getattr(model.model_args, "attn_type", "sdpa")
Contributor (on the added line above): can remove this line
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
Second changed file:

@@ -61,9 +61,10 @@ def parallelize_deepseekv3(
             ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
             """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+    attn_type = getattr(model.model_args, "attn_type", "sdpa")
Contributor: nit: in Python 3.11+, StrEnum seems like a good fit for this.
Contributor: TorchTitan still sticks to 3.10, AFAIK.
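To make the nit concrete, here is a minimal sketch (not code from this PR; the `AttnType` name and its members are assumptions): a `str`-mixin `Enum` gives the same string comparisons on Python 3.10, and `enum.StrEnum` could replace it once 3.11+ is the floor.

```python
# Hypothetical sketch of the reviewer's nit; AttnType and its members are
# illustrative names, not part of the PR.
from enum import Enum

class AttnType(str, Enum):
    # On Python 3.11+ this could simply subclass enum.StrEnum instead.
    SDPA = "sdpa"
    FLEX = "flex"
    VARLEN = "varlen"

# str-mixin members compare equal to the plain strings used in the config
# checks, so guards like `attn_type != "sdpa"` keep working unchanged.
assert AttnType.SDPA == "sdpa"
assert AttnType.FLEX != "sdpa"
```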
+    use_flex_attn = attn_type == "flex"
+    if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
+        raise NotImplementedError("CP support is only supported for SDPA.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters

@@ -84,7 +85,6 @@ def parallelize_deepseekv3(
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
-            use_flex_attn=use_flex_attn,
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
@@ -181,7 +181,6 @@ def apply_non_moe_tp(
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
     enable_float8_tensorwise_tp: bool,
-    use_flex_attn: bool,
 ):
     """Apply tensor parallelism."""
     # 1. Parallelize the embedding and shard its outputs (which are the first

@@ -211,18 +210,11 @@ def apply_non_moe_tp(
             PrepareModuleInput,
         )

-    if use_flex_attn:
-        attention_kernel_plan = prepare_module_input(
-            input_layouts=(Shard(1), Shard(1), Shard(1)),
-            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
-            use_local_output=True,
-        )
-    else:
-        attention_kernel_plan = prepare_module_input(
-            input_layouts=(Shard(1), Shard(1), Shard(1)),
-            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
-            use_local_output=True,
-        )
Comment on lines -214 to -225:

Contributor: This is existing duplicated code?
Contributor (author): Yeah, this was there before; I removed it per #2000 (comment).
+    attention_kernel_plan = prepare_module_input(
+        input_layouts=(Shard(1), Shard(1), Shard(1)),
+        desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
+        use_local_output=True,
+    )
     # Apply tensor + sequence parallelism to every transformer block
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
     # by folding (and unfolding) the batch dimension and the sequence dimension.
Comment: "varlen" should also work here?
Author: IIUC this isn't limited to llama3; I'll add varlen after more thorough testing for the other models.
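As a purely illustrative follow-up to this thread (the helper name and the tuple of accepted values are assumptions, not part of this PR): if "varlen" is later validated for the other models, a per-model guard might recognize it as an attention type while context parallelism remains SDPA-only.

```python
# Hypothetical sketch only; the helper and the tuple of names are assumptions
# based on the strings visible in this diff ("sdpa", "flex") plus "varlen".
KNOWN_ATTN_TYPES = ("sdpa", "flex", "varlen")

def check_attn_and_cp(attn_type: str, cp_degree: int) -> None:
    if attn_type not in KNOWN_ATTN_TYPES:
        raise ValueError(f"Unknown attn_type: {attn_type!r}")
    if cp_degree > 1 and attn_type != "sdpa":
        # Mirrors the guard added in this PR: context parallelism currently
        # supports only SDPA.
        raise NotImplementedError("CP support is only supported for SDPA.")

# Example: varlen attention without context parallelism passes the check.
check_attn_and_cp("varlen", cp_degree=1)
```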