
Conversation


@liangel-02 liangel-02 commented Nov 7, 2025

Summary
This PR adds variable-length attention (varlen) support to the Llama 3 8B model in torchtitan. We replace `use_flex_attn` with `attn_type` (one of `"sdpa"`, `"varlen"`, or `"flex"`). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined here.
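
(A minimal sketch of the dispatch described above, not the actual torchtitan code: the class name, the `torch.nn.attention.varlen` import path, and the `varlen_attn` call signature are assumptions based on this summary.)

```python
import torch
import torch.nn.functional as F


class InnerAttention(torch.nn.Module):
    """Illustrative attention module that dispatches on attn_type."""

    def __init__(self, attn_type: str = "sdpa"):
        super().__init__()
        assert attn_type in ("sdpa", "varlen", "flex")
        self.attn_type = attn_type
        if attn_type == "varlen":
            # Assumed import path/signature; the PR links to the real definition.
            from torch.nn.attention.varlen import varlen_attn

            self.varlen_attn = torch.compile(varlen_attn)

    def forward(self, q, k, v, **extra_kwargs):
        if self.attn_type == "sdpa":
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        if self.attn_type == "varlen":
            # extra_kwargs would carry packing metadata (e.g. cumulative sequence
            # lengths and max sequence length) for the variable-length batch.
            return self.varlen_attn(q, k, v, **extra_kwargs)
        raise NotImplementedError("flex is handled by a separate FlexAttentionWrapper")
```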

Testing
Ran loss and performance tests against flex attention. Loss is on par.

[Screenshot (2025-11-19): loss curves, varlen vs. flex attention]

Varlen is slightly slower than flex due to CUDA kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today).

|          | Varlen          | Flex            |
| -------- | --------------- | --------------- |
| Forward  | 774us 357ns     | 722us 317ns     |
| Backward | 1ms 955us 916ns | 1ms 558us 747ns |
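
(A minimal sketch of how per-kernel forward/backward timings like these could be collected with CUDA events; this is an assumption about methodology, not necessarily how the numbers above were measured, and `attn` is any attention callable.)

```python
import torch


def time_fwd_bwd(attn, q, k, v, iters=20):
    """Average forward/backward time in ms for an attention callable on CUDA."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fwd_ms = bwd_ms = 0.0
    for _ in range(iters):
        start.record()
        out = attn(q, k, v)
        end.record()
        torch.cuda.synchronize()
        fwd_ms += start.elapsed_time(end)

        grad = torch.randn_like(out)
        start.record()
        out.backward(grad)
        end.record()
        torch.cuda.synchronize()
        bwd_ms += start.elapsed_time(end)
        q.grad = k.grad = v.grad = None  # reset grads between iterations
    return fwd_ms / iters, bwd_ms / iters
```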

@meta-cla meta-cla bot added the CLA Signed label Nov 7, 2025
@liangel-02 liangel-02 force-pushed the test_varlen branch 3 times, most recently from eeecb63 to cad97e5 Compare November 12, 2025 22:49
@liangel-02 liangel-02 changed the title Test varlen adding variable length attention to llama 3 8b Nov 12, 2025
@liangel-02 liangel-02 changed the title adding variable length attention to llama 3 8b adding variable length attention to llama3 8b Nov 12, 2025
@liangel-02 liangel-02 requested a review from drisspg November 12, 2025 23:18

@fegin fegin left a comment

This implementation won't work with PP and is too intrusive to the model. The packing logic should be hidden inside the inner attention.

@liangel-02 liangel-02 force-pushed the test_varlen branch 4 times, most recently from 55352a5 to 066ca02 Compare November 14, 2025 18:11
@liangel-02 liangel-02 requested a review from fegin November 14, 2025 18:11
@liangel-02 liangel-02 marked this pull request as ready for review November 14, 2025 18:14

@fegin fegin left a comment

LGTM, thanks for the update. I left some other comments; once they are addressed, this PR should be ready.

@liangel-02 liangel-02 force-pushed the test_varlen branch 2 times, most recently from a902cbe to de416f9 Compare November 17, 2025 18:05
@liangel-02 liangel-02 requested a review from fegin November 17, 2025 18:05

@tianyu-l tianyu-l left a comment

Thanks! Left some comments; please see if they make sense to you.

@liangel-02 liangel-02 force-pushed the test_varlen branch 4 times, most recently from caafc81 to 4d36560 Compare November 18, 2025 21:49
@liangel-02 liangel-02 force-pushed the test_varlen branch 2 times, most recently from ca0efc0 to 291daea Compare November 19, 2025 18:27

```diff
-use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
+attn_type = getattr(model.model_args, "attn_type", "sdpa")
```
A contributor commented:

nit: in Python 3.11+, StrEnum seems like a good fit for this

Another contributor replied:

TorchTitan still sticks to 3.10 afaik.
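
(A sketch of the reviewer's nit adapted for Python 3.10: a `str`-mixin `Enum` gives similar ergonomics to `StrEnum` without requiring 3.11+. The `AttnType` name is illustrative, not an actual torchtitan definition.)

```python
from enum import Enum


class AttnType(str, Enum):
    SDPA = "sdpa"
    VARLEN = "varlen"
    FLEX = "flex"


# str-mixin members compare equal to their plain-string values, so checks like
# `attn_type == "flex"` keep working unchanged:
assert AttnType.FLEX == "flex"
# Constructing from a config string also validates it:
assert AttnType("varlen") is AttnType.VARLEN
```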

@liangel-02 liangel-02 force-pushed the test_varlen branch 4 times, most recently from 9380847 to 42c0c85 Compare November 19, 2025 22:33
@liangel-02 liangel-02 requested review from fegin and tianyu-l November 19, 2025 22:34

@tianyu-l tianyu-l left a comment

Left some more comments. If you'd like to focus on Llama 3 in this PR, that's fine with me too.

```diff
 extra_kwargs = {}

-if getattr(self.model_args, "use_flex_attn", False):
+if getattr(self.model_args, "attn_type", "sdpa") == "flex":
```
A contributor commented:

"varlen" should also work here?

@liangel-02 (author) replied:

IIUC this isn't limited to Llama 3; I'll add varlen here after more thorough testing for the other models.
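
(A hypothetical sketch of the suggestion above, deferred to a follow-up: include "varlen" alongside "flex" when building the extra kwargs. `model_args` stands in for `self.model_args`; the `attention_masks` key is illustrative.)

```python
from types import SimpleNamespace

model_args = SimpleNamespace(attn_type="varlen")  # stand-in for self.model_args

extra_kwargs = {}
# Treat "varlen" like "flex" whenever mask/packing metadata is needed.
if getattr(model_args, "attn_type", "sdpa") in ("flex", "varlen"):
    extra_kwargs["attention_masks"] = ...  # e.g. built via get_attention_masks
```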

```python
match self.attn_type:
    case "flex":
        self.inner_attention = FlexAttentionWrapper()
    case _:
```
A contributor commented:

How about varlen? Also, it seems the `get_attention_masks` function in this file is not changed.
If the scope of this PR is to support Llama 3 only, that's fine too.

@liangel-02 (author) replied:

I think we can limit this PR to Llama 3 and add support for other models later.
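
(A hypothetical sketch of that follow-up: extending the match above with a varlen arm. `VarlenAttentionWrapper` and `ScaledDotProductAttentionWrapper` are placeholder names, not confirmed torchtitan APIs.)

```python
match self.attn_type:
    case "flex":
        self.inner_attention = FlexAttentionWrapper()
    case "varlen":
        # Placeholder: a wrapper around the compiled varlen_attn.
        self.inner_attention = VarlenAttentionWrapper()
    case _:
        # Placeholder: the default SDPA path.
        self.inner_attention = ScaledDotProductAttentionWrapper()
```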

@liangel-02 liangel-02 force-pushed the test_varlen branch 4 times, most recently from 5528029 to 31c1c77 Compare November 20, 2025 17:35

@fegin fegin left a comment

LGTM, we can leave other models to other PR(s).

Comment on lines -214 to -225
```python
if use_flex_attn:
    attention_kernel_plan = prepare_module_input(
        input_layouts=(Shard(1), Shard(1), Shard(1)),
        desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
        use_local_output=True,
    )
else:
    attention_kernel_plan = prepare_module_input(
        input_layouts=(Shard(1), Shard(1), Shard(1)),
        desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
        use_local_output=True,
    )
```
A contributor commented:

Is this existing duplicated code?

@liangel-02 (author) replied:

Yeah, this was there before; I removed it per #2000 (comment).
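
(For illustration only; the PR deleted the block per the linked comment. Since both branches built identical plans, the if/else could equally have collapsed to a single call:)

```python
attention_kernel_plan = prepare_module_input(
    input_layouts=(Shard(1), Shard(1), Shard(1)),
    desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
    use_local_output=True,
)
```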

@liangel-02 liangel-02 force-pushed the test_varlen branch 3 times, most recently from 697b9b9 to b717da3 Compare November 20, 2025 18:48

Labels: CLA Signed (this label is managed by the Meta Open Source bot)

6 participants