adding variable length attention to llama3 8b #2000
Conversation
(force-pushed from eeecb63 to cad97e5)
fegin
left a comment
This implementation won't work with PP and is too model-intrusive. The packing logic should be hidden inside the inner attention.
(force-pushed from 55352a5 to 066ca02)
(force-pushed from 066ca02 to c9b6d5c)
fegin
left a comment
LGTM, thanks for the update. Left some other comments; once they are addressed, this PR should be ready.
(force-pushed from a902cbe to de416f9)
tianyu-l
left a comment
Thanks! Left some comments, please see if they make sense to you.
(force-pushed from caafc81 to 4d36560)
(force-pushed from ca0efc0 to 291daea)
```
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
attn_type = getattr(model.model_args, "attn_type", "sdpa")
```
nit: in Python 3.11+, StrEnum seems like a good fit for this
TorchTitan still sticks to 3.10 afaik.
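For reference, a minimal sketch of the 3.10-compatible alternative: a `str`-subclassing `Enum` gives `StrEnum`-like ergonomics without requiring Python 3.11. `AttnType` here is illustrative, not the PR's actual config type.

```python
from enum import Enum


class AttnType(str, Enum):
    """Illustrative only, not the PR's code: on Python 3.10, where enum.StrEnum
    (added in 3.11) is unavailable, subclassing str gives similar ergonomics --
    members compare equal to their string values, so string-based checks keep working."""

    SDPA = "sdpa"
    FLEX = "flex"
    VARLEN = "varlen"


assert AttnType.FLEX == "flex"                # plain string comparison still works
assert AttnType("varlen") is AttnType.VARLEN  # parse from a config string
```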
(force-pushed from 9380847 to 42c0c85)
tianyu-l
left a comment
Left some more comments. If you'd like to focus on Llama 3 in this PR, that's fine with me too.
```
extra_kwargs = {}

if getattr(self.model_args, "use_flex_attn", False):
if getattr(self.model_args, "attn_type", "sdpa") == "flex":
```
"varlen" should also work here?
IIUC this isn't limited to Llama 3; I'll add varlen here for the other models after more thorough testing.
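As a rough sketch of the reviewer's suggestion (using a stand-in `model_args` object rather than torchtitan's real config), the check could branch on membership so that "flex" and "varlen" share the same path once varlen is enabled here:

```python
from types import SimpleNamespace

# Stand-in config object for illustration only; torchtitan's model_args is richer.
model_args = SimpleNamespace(attn_type="varlen")

# Hedged sketch: both fused-attention backends take the same branch.
if getattr(model_args, "attn_type", "sdpa") in ("flex", "varlen"):
    print("fused attention backend selected")
```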
```python
match self.attn_type:
    case "flex":
        self.inner_attention = FlexAttentionWrapper()
    case _:
```
How about varlen? Also, it seems the get_attention_masks function in this file is not changed.
If the scope of this PR is to support Llama 3 only, that's fine too.
I think we can limit this PR to Llama 3 and add support for the other models later.
(force-pushed from 5528029 to 31c1c77)
fegin
left a comment
LGTM, we can leave other models to other PR(s).
```python
if use_flex_attn:
    attention_kernel_plan = prepare_module_input(
        input_layouts=(Shard(1), Shard(1), Shard(1)),
        desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
        use_local_output=True,
    )
else:
    attention_kernel_plan = prepare_module_input(
        input_layouts=(Shard(1), Shard(1), Shard(1)),
        desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
        use_local_output=True,
    )
```
Is this existing duplicated code?
Yeah, this was there before; I removed it per #2000 (comment)
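If the block were kept, de-duplicating it would collapse both branches into a single call. A sketch reusing the identifiers from the quoted diff (assuming `prepare_module_input` and `Shard` are already in scope in that file), shown for illustration rather than as the exact code that landed:

```python
# Illustrative de-duplicated form of the quoted block, not necessarily what landed.
attention_kernel_plan = prepare_module_input(
    input_layouts=(Shard(1), Shard(1), Shard(1)),
    desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
    use_local_output=True,
)
```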
(force-pushed from 697b9b9 to b717da3)
(force-pushed from b717da3 to 9c99fcb)
Summary
This PR adds variable length attention (varlen) support to the Llama 3 8B model in torchtitan. We replace `use_flex_attn` with `attn_type` (one of `"sdpa"`, `"varlen"`, or `"flex"`). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined here.

Testing
Ran loss and performance tests against flex attention. Loss is on par.
Varlen is slightly slower than flex due to the underlying CUDA kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today).
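For readers unfamiliar with varlen attention, the usual pattern is to pack several documents into one flattened sequence and hand the kernel cumulative sequence lengths so attention stays within document boundaries. The helper below is a hypothetical illustration of how such offsets could be built; it is not the PR's implementation.

```python
import torch


def build_cu_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper, for illustration only (not the PR's code).

    Varlen-style kernels typically take int32 cumulative sequence lengths marking
    where each packed document starts and ends in the flattened batch.
    """
    cu_seqlens = torch.zeros(
        seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device
    )
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return cu_seqlens


if __name__ == "__main__":
    # Three packed documents of lengths 3, 5, and 2 tokens.
    print(build_cu_seqlens(torch.tensor([3, 5, 2])))
    # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```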